From 99a179c75326992758b94ff35ce01f497fc85cc4 Mon Sep 17 00:00:00 2001 From: kaarthik108 Date: Sun, 13 Oct 2024 22:15:33 +1300 Subject: [PATCH 1/9] Use agents --- .github/workflows/lint.yml | 27 +++- Makefile | 57 +++++++ agent.py | 102 ++++++++++++ chain.py | 309 ++++++++++++++++++------------------- ingest.py | 1 + main.py | 77 +++++---- template.py | 1 - tools.py | 28 ++++ utils/snow_connect.py | 1 - utils/snowchat_ui.py | 91 +++++------ 10 files changed, 452 insertions(+), 242 deletions(-) create mode 100644 Makefile create mode 100644 agent.py create mode 100644 tools.py diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 97e02502..6c7256e3 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -4,26 +4,39 @@ on: push: branches: - main - - prod pull_request: branches: - main - - prod jobs: lint: + name: Lint and Format Code runs-on: ubuntu-latest + steps: - name: Check out repository uses: actions/checkout@v3 - name: Set up Python - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 + with: + python-version: "3.9" + + - name: Cache pip dependencies + uses: actions/cache@v3 with: - python-version: 3.9 + path: ~/.cache/pip + key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }} + restore-keys: | + ${{ runner.os }}-pip- - name: Install dependencies - run: pip install black + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + pip install black ruff mypy codespell - - name: Lint with black - run: black --check . + - name: Run Formatting and Linting + run: | + make format + make lint diff --git a/Makefile b/Makefile new file mode 100644 index 00000000..a768e53a --- /dev/null +++ b/Makefile @@ -0,0 +1,57 @@ +.PHONY: all format lint lint_diff format_diff lint_package lint_tests spell_check spell_fix help lint-fix + +# Define a variable for Python and notebook files. +PYTHON_FILES=src/ +MYPY_CACHE=.mypy_cache + +###################### +# LINTING AND FORMATTING +###################### + +lint format: PYTHON_FILES=. +lint_diff format_diff: PYTHON_FILES=$(shell git diff --name-only --diff-filter=d main | grep -E '\.py$$|\.ipynb$$') +lint_package: PYTHON_FILES=src +lint_tests: PYTHON_FILES=tests +lint_tests: MYPY_CACHE=.mypy_cache_test + +lint lint_diff lint_package lint_tests: + python -m ruff check . + [ "$(PYTHON_FILES)" = "" ] || python -m ruff format $(PYTHON_FILES) --diff + [ "$(PYTHON_FILES)" = "" ] || python -m ruff check --select I,F401 --fix $(PYTHON_FILES) + [ "$(PYTHON_FILES)" = "" ] || python -m mypy --strict $(PYTHON_FILES) + [ "$(PYTHON_FILES)" = "" ] || mkdir -p $(MYPY_CACHE) && python -m mypy --strict $(PYTHON_FILES) --cache-dir $(MYPY_CACHE) + +format format_diff: + ruff format $(PYTHON_FILES) + ruff check --fix $(PYTHON_FILES) + +spell_check: + codespell --toml pyproject.toml + +spell_fix: + codespell --toml pyproject.toml -w + +###################### +# RUN ALL +###################### + +all: format lint spell_check + +###################### +# HELP +###################### + +help: + @echo '----' + @echo 'format - run code formatters' + @echo 'lint - run linters' + @echo 'spell_check - run spell check' + @echo 'all - run all tasks' + @echo 'lint-fix - run lint and fix issues' + +###################### +# LINT-FIX TARGET +###################### + +lint-fix: format lint + @echo "Linting and fixing completed successfully." diff --git a/agent.py b/agent.py new file mode 100644 index 00000000..a7a7462b --- /dev/null +++ b/agent.py @@ -0,0 +1,102 @@ +import os +from dataclasses import dataclass +from typing import Annotated, Sequence, Optional + +from langchain.callbacks.base import BaseCallbackHandler +from langchain_anthropic import ChatAnthropic +from langchain_core.messages import SystemMessage +from langchain_openai import ChatOpenAI +from langgraph.checkpoint.memory import MemorySaver +from langgraph.graph import START, StateGraph +from langgraph.prebuilt import ToolNode, tools_condition +from langgraph.graph.message import add_messages +from langchain_core.messages import BaseMessage + +from template import TEMPLATE +from tools import retriever_tool + + +@dataclass +class MessagesState: + messages: Annotated[Sequence[BaseMessage], add_messages] + + +memory = MemorySaver() + + +@dataclass +class ModelConfig: + model_name: str + api_key: str + base_url: Optional[str] = None + + +def create_agent(callback_handler: BaseCallbackHandler, model_name: str): + model_configurations = { + "gpt-4o-mini": ModelConfig( + model_name="gpt-4o-mini", api_key=os.getenv("OPENAI_API_KEY") + ), + "gemma2-9b": ModelConfig( + model_name="gemma2-9b-it", + api_key=os.getenv("GROQ_API_KEY"), + base_url="https://api.groq.com/openai/v1", + ), + "claude3-haiku": ModelConfig( + model_name="claude-3-haiku-20240307", api_key=os.getenv("ANTHROPIC_API_KEY") + ), + "mixtral-8x22b": ModelConfig( + model_name="accounts/fireworks/models/mixtral-8x22b-instruct", + api_key=os.getenv("FIREWORKS_API_KEY"), + base_url="https://api.fireworks.ai/inference/v1", + ), + "llama-3.1-405b": ModelConfig( + model_name="accounts/fireworks/models/llama-v3p1-405b-instruct", + api_key=os.getenv("FIREWORKS_API_KEY"), + base_url="https://api.fireworks.ai/inference/v1", + ), + } + config = model_configurations.get(model_name) + if not config: + raise ValueError(f"Unsupported model name: {model_name}") + + sys_msg = SystemMessage( + content="""You're an AI assistant specializing in data analysis with Snowflake SQL. When providing responses, strive to exhibit friendliness and adopt a conversational tone, similar to how a friend or tutor would communicate. + Call the tool "Database_Schema" to search for database schema details when needed to generate the SQL code. + """ + ) + + llm = ( + ChatOpenAI( + model=config.model_name, + api_key=config.api_key, + callbacks=[callback_handler], + streaming=True, + base_url=config.base_url, + ) + if config.model_name != "claude-3-haiku-20240307" + else ChatAnthropic( + model=config.model_name, + api_key=config.api_key, + callbacks=[callback_handler], + streaming=True, + ) + ) + + tools = [retriever_tool] + + llm_with_tools = llm.bind_tools(tools) + + def reasoner(state: MessagesState): + return {"messages": [llm_with_tools.invoke([sys_msg] + state.messages)]} + + # Build the graph + builder = StateGraph(MessagesState) + builder.add_node("reasoner", reasoner) + builder.add_node("tools", ToolNode(tools)) + builder.add_edge(START, "reasoner") + builder.add_conditional_edges("reasoner", tools_condition) + builder.add_edge("tools", "reasoner") + + react_graph = builder.compile(checkpointer=memory) + + return react_graph diff --git a/chain.py b/chain.py index a16b8627..a0c25952 100644 --- a/chain.py +++ b/chain.py @@ -1,155 +1,154 @@ -from typing import Any, Callable, Dict, Optional - -import streamlit as st -from langchain_community.chat_models import ChatOpenAI -from langchain.embeddings.openai import OpenAIEmbeddings -from langchain.llms import OpenAI -from langchain.vectorstores import SupabaseVectorStore -from pydantic import BaseModel, validator -from supabase.client import Client, create_client - -from template import CONDENSE_QUESTION_PROMPT, QA_PROMPT - -from operator import itemgetter - -from langchain.prompts.prompt import PromptTemplate -from langchain.schema import format_document -from langchain_core.messages import get_buffer_string -from langchain_core.output_parsers import StrOutputParser -from langchain_core.runnables import RunnableParallel, RunnablePassthrough -from langchain_openai import ChatOpenAI, OpenAIEmbeddings -from langchain_anthropic import ChatAnthropic - -DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}") - -supabase_url = st.secrets["SUPABASE_URL"] -supabase_key = st.secrets["SUPABASE_SERVICE_KEY"] -supabase: Client = create_client(supabase_url, supabase_key) - - -class ModelConfig(BaseModel): - model_type: str - secrets: Dict[str, Any] - callback_handler: Optional[Callable] = None - - -class ModelWrapper: - def __init__(self, config: ModelConfig): - self.model_type = config.model_type - self.secrets = config.secrets - self.callback_handler = config.callback_handler - self.llm = self._setup_llm() - - def _setup_llm(self): - model_config = { - "gpt-4o-mini": { - "model_name": "gpt-4o-mini", - "api_key": self.secrets["OPENAI_API_KEY"], - }, - "gemma2-9b": { - "model_name": "gemma2-9b-it", - "api_key": self.secrets["GROQ_API_KEY"], - "base_url": "https://api.groq.com/openai/v1", - }, - "claude3-haiku": { - "model_name": "claude-3-haiku-20240307", - "api_key": self.secrets["ANTHROPIC_API_KEY"], - }, - "mixtral-8x22b": { - "model_name": "accounts/fireworks/models/mixtral-8x22b-instruct", - "api_key": self.secrets["FIREWORKS_API_KEY"], - "base_url": "https://api.fireworks.ai/inference/v1", - }, - "llama-3.1-405b": { - "model_name": "accounts/fireworks/models/llama-v3p1-405b-instruct", - "api_key": self.secrets["FIREWORKS_API_KEY"], - "base_url": "https://api.fireworks.ai/inference/v1", - }, - } - - config = model_config[self.model_type] - - return ( - ChatOpenAI( - model_name=config["model_name"], - temperature=0.1, - api_key=config["api_key"], - max_tokens=700, - callbacks=[self.callback_handler], - streaming=True, - base_url=config["base_url"] - if config["model_name"] != "gpt-4o-mini" - else None, - default_headers={ - "HTTP-Referer": "https://snowchat.streamlit.app/", - "X-Title": "Snowchat", - }, - ) - if config["model_name"] != "claude-3-haiku-20240307" - else ( - ChatAnthropic( - model=config["model_name"], - temperature=0.1, - max_tokens=700, - timeout=None, - max_retries=2, - callbacks=[self.callback_handler], - streaming=True, - ) - ) - ) - - def get_chain(self, vectorstore): - def _combine_documents( - docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, document_separator="\n\n" - ): - doc_strings = [format_document(doc, document_prompt) for doc in docs] - return document_separator.join(doc_strings) - - _inputs = RunnableParallel( - standalone_question=RunnablePassthrough.assign( - chat_history=lambda x: get_buffer_string(x["chat_history"]) - ) - | CONDENSE_QUESTION_PROMPT - | OpenAI() - | StrOutputParser(), - ) - _context = { - "context": itemgetter("standalone_question") - | vectorstore.as_retriever() - | _combine_documents, - "question": lambda x: x["standalone_question"], - } - conversational_qa_chain = _inputs | _context | QA_PROMPT | self.llm - - return conversational_qa_chain - - -def load_chain(model_name="qwen", callback_handler=None): - embeddings = OpenAIEmbeddings( - openai_api_key=st.secrets["OPENAI_API_KEY"], model="text-embedding-ada-002" - ) - vectorstore = SupabaseVectorStore( - embedding=embeddings, - client=supabase, - table_name="documents", - query_name="v_match_documents", - ) - - model_type_mapping = { - "gpt-4o-mini": "gpt-4o-mini", - "gemma2-9b": "gemma2-9b", - "claude3-haiku": "claude3-haiku", - "mixtral-8x22b": "mixtral-8x22b", - "llama-3.1-405b": "llama-3.1-405b", - } - - model_type = model_type_mapping.get(model_name.lower()) - if model_type is None: - raise ValueError(f"Unsupported model name: {model_name}") - - config = ModelConfig( - model_type=model_type, secrets=st.secrets, callback_handler=callback_handler - ) - model = ModelWrapper(config) - return model.get_chain(vectorstore) +# from dataclasses import dataclass, field +# from operator import itemgetter +# from typing import Any, Callable, Dict, Optional + +# import streamlit as st +# from langchain.embeddings.openai import OpenAIEmbeddings +# from langchain.llms import OpenAI +# from langchain.prompts.prompt import PromptTemplate +# from langchain.schema import format_document +# from langchain.vectorstores import SupabaseVectorStore +# from langchain_anthropic import ChatAnthropic +# from langchain_community.chat_models import ChatOpenAI +# from langchain_core.messages import get_buffer_string +# from langchain_core.output_parsers import StrOutputParser +# from langchain_core.runnables import RunnableParallel, RunnablePassthrough +# from langchain_openai import ChatOpenAI, OpenAIEmbeddings + +# from supabase.client import Client, create_client +# from template import CONDENSE_QUESTION_PROMPT, QA_PROMPT + +# DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}") + +# supabase_url = st.secrets["SUPABASE_URL"] +# supabase_key = st.secrets["SUPABASE_SERVICE_KEY"] +# supabase: Client = create_client(supabase_url, supabase_key) + + +# @dataclass +# class ModelConfig: +# model_type: str +# secrets: Dict[str, Any] +# callback_handler: Optional[Callable] = field(default=None) + + +# class ModelWrapper: +# def __init__(self, config: ModelConfig): +# self.model_type = config.model_type +# self.secrets = config.secrets +# self.callback_handler = config.callback_handler +# self.llm = self._setup_llm() + +# def _setup_llm(self): +# model_config = { +# "gpt-4o-mini": { +# "model_name": "gpt-4o-mini", +# "api_key": self.secrets["OPENAI_API_KEY"], +# }, +# "gemma2-9b": { +# "model_name": "gemma2-9b-it", +# "api_key": self.secrets["GROQ_API_KEY"], +# "base_url": "https://api.groq.com/openai/v1", +# }, +# "claude3-haiku": { +# "model_name": "claude-3-haiku-20240307", +# "api_key": self.secrets["ANTHROPIC_API_KEY"], +# }, +# "mixtral-8x22b": { +# "model_name": "accounts/fireworks/models/mixtral-8x22b-instruct", +# "api_key": self.secrets["FIREWORKS_API_KEY"], +# "base_url": "https://api.fireworks.ai/inference/v1", +# }, +# "llama-3.1-405b": { +# "model_name": "accounts/fireworks/models/llama-v3p1-405b-instruct", +# "api_key": self.secrets["FIREWORKS_API_KEY"], +# "base_url": "https://api.fireworks.ai/inference/v1", +# }, +# } + +# config = model_config[self.model_type] + +# return ( +# ChatOpenAI( +# model_name=config["model_name"], +# temperature=0.1, +# api_key=config["api_key"], +# max_tokens=700, +# callbacks=[self.callback_handler], +# streaming=True, +# base_url=config["base_url"] +# if config["model_name"] != "gpt-4o-mini" +# else None, +# default_headers={ +# "HTTP-Referer": "https://snowchat.streamlit.app/", +# "X-Title": "Snowchat", +# }, +# ) +# if config["model_name"] != "claude-3-haiku-20240307" +# else ( +# ChatAnthropic( +# model=config["model_name"], +# temperature=0.1, +# max_tokens=700, +# timeout=None, +# max_retries=2, +# callbacks=[self.callback_handler], +# streaming=True, +# ) +# ) +# ) + +# def get_chain(self, vectorstore): +# def _combine_documents( +# docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, document_separator="\n\n" +# ): +# doc_strings = [format_document(doc, document_prompt) for doc in docs] +# return document_separator.join(doc_strings) + +# _inputs = RunnableParallel( +# standalone_question=RunnablePassthrough.assign( +# chat_history=lambda x: get_buffer_string(x["chat_history"]) +# ) +# | CONDENSE_QUESTION_PROMPT +# | OpenAI() +# | StrOutputParser(), +# ) +# _context = { +# "context": itemgetter("standalone_question") +# | vectorstore.as_retriever() +# | _combine_documents, +# "question": lambda x: x["standalone_question"], +# } +# conversational_qa_chain = _inputs | _context | QA_PROMPT | self.llm + +# return conversational_qa_chain + + +# def load_chain(model_name="qwen", callback_handler=None): +# embeddings = OpenAIEmbeddings( +# openai_api_key=st.secrets["OPENAI_API_KEY"], model="text-embedding-ada-002" +# ) +# vectorstore = SupabaseVectorStore( +# embedding=embeddings, +# client=supabase, +# table_name="documents", +# query_name="v_match_documents", +# ) + +# model_type_mapping = { +# "gpt-4o-mini": "gpt-4o-mini", +# "gemma2-9b": "gemma2-9b", +# "claude3-haiku": "claude3-haiku", +# "mixtral-8x22b": "mixtral-8x22b", +# "llama-3.1-405b": "llama-3.1-405b", +# } + +# model_type = model_type_mapping.get(model_name.lower()) +# if model_type is None: +# raise ValueError(f"Unsupported model name: {model_name}") + +# config = ModelConfig( +# model_type=model_type, secrets=st.secrets, callback_handler=callback_handler +# ) +# model = ModelWrapper(config) +# return model.get_chain(vectorstore) diff --git a/ingest.py b/ingest.py index 67de0f33..c6669f35 100644 --- a/ingest.py +++ b/ingest.py @@ -6,6 +6,7 @@ from langchain.text_splitter import CharacterTextSplitter from langchain.vectorstores import SupabaseVectorStore from pydantic import BaseModel + from supabase.client import Client, create_client diff --git a/main.py b/main.py index 1a491be1..97a6ddce 100644 --- a/main.py +++ b/main.py @@ -2,9 +2,10 @@ import warnings import streamlit as st +from langchain_core.messages import HumanMessage from snowflake.snowpark.exceptions import SnowparkSQLException -from chain import load_chain +from agent import MessagesState, create_agent # from utils.snow_connect import SnowflakeConnection from utils.snowchat_ui import StreamlitUICallbackHandler, message_func @@ -50,6 +51,9 @@ ) st.session_state["model"] = model +if "assistant_response_processed" not in st.session_state: + st.session_state["assistant_response_processed"] = True # Initialize to True + if "toast_shown" not in st.session_state: st.session_state["toast_shown"] = False @@ -76,6 +80,7 @@ "content": "Hey there, I'm Chatty McQueryFace, your SQL-speaking sidekick, ready to chat up Snowflake and fetch answers faster than a snowball fight in summer! ❄️🔍", }, ] +config = {"configurable": {"thread_id": "42"}} with open("ui/sidebar.md", "r") as sidebar_file: sidebar_content = sidebar_file.read() @@ -118,18 +123,28 @@ # Prompt for user input and save if prompt := st.chat_input(): st.session_state.messages.append({"role": "user", "content": prompt}) + st.session_state["assistant_response_processed"] = ( + False # Assistant response not yet processed + ) -for message in st.session_state.messages: +messages_to_display = st.session_state.messages.copy() +# if not st.session_state["assistant_response_processed"]: +# # Exclude the last assistant message if assistant response not yet processed +# if messages_to_display and messages_to_display[-1]["role"] == "assistant": +# print("\n\nthis is messages_to_display \n\n", messages_to_display) +# messages_to_display = messages_to_display[:-1] + +for message in messages_to_display: message_func( message["content"], - True if message["role"] == "user" else False, - True if message["role"] == "data" else False, - model, + is_user=(message["role"] == "user"), + is_df=(message["role"] == "data"), + model=model, ) callback_handler = StreamlitUICallbackHandler(model) -chain = load_chain(st.session_state["model"], callback_handler) +react_graph = create_agent(callback_handler, st.session_state["model"]) def append_chat_history(question, answer): @@ -148,20 +163,21 @@ def append_message(content, role="assistant"): def handle_sql_exception(query, conn, e, retries=2): - append_message("Uh oh, I made an error, let me try to fix it..") - error_message = ( - "You gave me a wrong SQL. FIX The SQL query by searching the schema definition: \n```sql\n" - + query - + "\n```\n Error message: \n " - + str(e) - ) - new_query = chain({"question": error_message, "chat_history": ""})["answer"] - append_message(new_query) - if get_sql(new_query) and retries > 0: - return execute_sql(get_sql(new_query), conn, retries - 1) - else: - append_message("I'm sorry, I couldn't fix the error. Please try again.") - return None + # append_message("Uh oh, I made an error, let me try to fix it..") + # error_message = ( + # "You gave me a wrong SQL. FIX The SQL query by searching the schema definition: \n```sql\n" + # + query + # + "\n```\n Error message: \n " + # + str(e) + # ) + # new_query = chain({"question": error_message, "chat_history": ""})["answer"] + # append_message(new_query) + # if get_sql(new_query) and retries > 0: + # return execute_sql(get_sql(new_query), conn, retries - 1) + # else: + # append_message("I'm sorry, I couldn't fix the error. Please try again.") + # return None + pass def execute_sql(query, conn, retries=2): @@ -176,20 +192,25 @@ def execute_sql(query, conn, retries=2): if ( "messages" in st.session_state - and st.session_state["messages"][-1]["role"] != "assistant" + and st.session_state["messages"][-1]["role"] == "user" + and not st.session_state["assistant_response_processed"] ): user_input_content = st.session_state["messages"][-1]["content"] if isinstance(user_input_content, str): + # Start loading animation callback_handler.start_loading_message() - result = chain.invoke( - { - "question": user_input_content, - "chat_history": [h for h in st.session_state["history"]], - } - ) - append_message(result.content) + messages = [HumanMessage(content=user_input_content)] + + state = MessagesState(messages=messages) + result = react_graph.invoke(state, config=config) + + if result["messages"]: + assistant_message = callback_handler.final_message + append_message(assistant_message) + st.session_state["assistant_response_processed"] = True + if ( st.session_state["model"] == "Mixtral 8x7B" diff --git a/template.py b/template.py index c8cd086c..5cc1759a 100644 --- a/template.py +++ b/template.py @@ -1,4 +1,3 @@ -from langchain.prompts.prompt import PromptTemplate from langchain_core.prompts import ChatPromptTemplate template = """You are an AI chatbot having a conversation with a human. diff --git a/tools.py b/tools.py new file mode 100644 index 00000000..5b5a4504 --- /dev/null +++ b/tools.py @@ -0,0 +1,28 @@ +import streamlit as st +from langchain.prompts.prompt import PromptTemplate +from supabase.client import Client, create_client +from langchain.tools.retriever import create_retriever_tool +from langchain_openai import OpenAIEmbeddings +from langchain_community.vectorstores import SupabaseVectorStore + +supabase_url = st.secrets["SUPABASE_URL"] +supabase_key = st.secrets["SUPABASE_SERVICE_KEY"] +supabase: Client = create_client(supabase_url, supabase_key) + + +embeddings = OpenAIEmbeddings( + openai_api_key=st.secrets["OPENAI_API_KEY"], model="text-embedding-ada-002" +) +vectorstore = SupabaseVectorStore( + embedding=embeddings, + client=supabase, + table_name="documents", + query_name="v_match_documents", +) + + +retriever_tool = create_retriever_tool( + vectorstore.as_retriever(), + name="Database_Schema", + description="Search for database schema details", +) diff --git a/utils/snow_connect.py b/utils/snow_connect.py index 2268c8bd..d0b396a6 100644 --- a/utils/snow_connect.py +++ b/utils/snow_connect.py @@ -2,7 +2,6 @@ import streamlit as st from snowflake.snowpark.session import Session -from snowflake.snowpark.version import VERSION class SnowflakeConnection: diff --git a/utils/snowchat_ui.py b/utils/snowchat_ui.py index 03f5f58f..98a63370 100644 --- a/utils/snowchat_ui.py +++ b/utils/snowchat_ui.py @@ -1,10 +1,10 @@ import html import re +import textwrap import streamlit as st from langchain.callbacks.base import BaseCallbackHandler - image_url = f"{st.secrets['SUPABASE_STORAGE_URL']}/storage/v1/object/public/snowchat/" gemini_url = image_url + "google-gemini-icon.png?t=2024-05-07T21%3A17%3A52.235Z" mistral_url = ( @@ -61,7 +61,7 @@ def format_message(text): def message_func(text, is_user=False, is_df=False, model="gpt"): """ - This function is used to display the messages in the chatbot UI. + This function displays messages in the chatbot UI, ensuring proper alignment and avatar positioning. Parameters: text (str): The text to be displayed. @@ -69,52 +69,36 @@ def message_func(text, is_user=False, is_df=False, model="gpt"): is_df (bool): Whether the message is a dataframe or not. """ model_url = get_model_url(model) + avatar_url = user_url if is_user else model_url + message_bg_color = ( + "linear-gradient(135deg, #00B2FF 0%, #006AFF 100%)" if is_user else "#71797E" + ) + avatar_class = "user-avatar" if is_user else "bot-avatar" + alignment = "flex-end" if is_user else "flex-start" + margin_side = "margin-left" if is_user else "margin-right" + message_text = html.escape(text.strip()).replace('\n', '
') - avatar_url = model_url if is_user: - avatar_url = user_url - message_alignment = "flex-end" - message_bg_color = "linear-gradient(135deg, #00B2FF 0%, #006AFF 100%)" - avatar_class = "user-avatar" - st.write( - f""" -
-
- {text} \n
- avatar -
- """, - unsafe_allow_html=True, - ) + container_html = f""" +
+
+ {message_text} +
+ avatar +
+ """ else: - message_alignment = "flex-start" - message_bg_color = "#71797E" - avatar_class = "bot-avatar" + container_html = f""" +
+ avatar +
+ {message_text} +
+
+ """ - if is_df: - st.write( - f""" -
- avatar -
- """, - unsafe_allow_html=True, - ) - st.write(text) - return - else: - text = format_message(text) + st.write(container_html, unsafe_allow_html=True) - st.write( - f""" -
- avatar -
- {text} \n
-
- """, - unsafe_allow_html=True, - ) class StreamlitUICallbackHandler(BaseCallbackHandler): @@ -125,6 +109,7 @@ def __init__(self, model): self.has_streaming_started = False self.model = model self.avatar_url = get_model_url(model) + self.final_message = "" def start_loading_message(self): loading_message_content = self._get_bot_message_container("Thinking...") @@ -138,6 +123,7 @@ def on_llm_new_token(self, token, run_id, parent_run_id=None, **kwargs): complete_message = "".join(self.token_buffer) container_content = self._get_bot_message_container(complete_message) self.placeholder.markdown(container_content, unsafe_allow_html=True) + self.final_message = "".join(self.token_buffer) def on_llm_end(self, response, run_id, parent_run_id=None, **kwargs): self.token_buffer = [] @@ -146,16 +132,20 @@ def on_llm_end(self, response, run_id, parent_run_id=None, **kwargs): def _get_bot_message_container(self, text): """Generate the bot's message container style for the given text.""" - formatted_text = format_message(text) + formatted_text = format_message(text.strip()) container_content = f""" -
- avatar -
- {formatted_text} \n
+
+ avatar +
+ {formatted_text}
+
""" return container_content + + + def display_dataframe(self, df): """ Display the dataframe in Streamlit UI within the chat container. @@ -165,13 +155,14 @@ def display_dataframe(self, df): st.write( f""" -
- avatar +
+ avatar
""", unsafe_allow_html=True, ) st.write(df) + def __call__(self, *args, **kwargs): pass From 28cbf6af920ecff52e2730d0214d212d431bc54d Mon Sep 17 00:00:00 2001 From: kaarthik108 Date: Sun, 13 Oct 2024 22:45:52 +1300 Subject: [PATCH 2/9] use llama3.2 --- agent.py | 8 ++++---- main.py | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/agent.py b/agent.py index a7a7462b..a13e8bd3 100644 --- a/agent.py +++ b/agent.py @@ -33,8 +33,8 @@ class ModelConfig: def create_agent(callback_handler: BaseCallbackHandler, model_name: str): model_configurations = { - "gpt-4o-mini": ModelConfig( - model_name="gpt-4o-mini", api_key=os.getenv("OPENAI_API_KEY") + "gpt-4o": ModelConfig( + model_name="gpt-4o", api_key=os.getenv("OPENAI_API_KEY") ), "gemma2-9b": ModelConfig( model_name="gemma2-9b-it", @@ -44,8 +44,8 @@ def create_agent(callback_handler: BaseCallbackHandler, model_name: str): "claude3-haiku": ModelConfig( model_name="claude-3-haiku-20240307", api_key=os.getenv("ANTHROPIC_API_KEY") ), - "mixtral-8x22b": ModelConfig( - model_name="accounts/fireworks/models/mixtral-8x22b-instruct", + "llama-3.2-3b": ModelConfig( + model_name="accounts/fireworks/models/llama-v3p2-3b-instruct", api_key=os.getenv("FIREWORKS_API_KEY"), base_url="https://api.fireworks.ai/inference/v1", ), diff --git a/main.py b/main.py index 97a6ddce..fe540794 100644 --- a/main.py +++ b/main.py @@ -35,11 +35,11 @@ st.caption("Talk your way through data") model_options = { - "gpt-4o-mini": "GPT-4o Mini", + "gpt-4o": "GPT-4o", "llama-3.1-405b": "Llama 3.1 405B", "gemma2-9b": "Gemma 2 9B", "claude3-haiku": "Claude 3 Haiku", - "mixtral-8x22b": "Mixtral 8x22B", + "llama-3.2-3b": "Llama 3.2 3B", } model = st.radio( From 77fd93e907769f8297038ad71befb8472d5b164d Mon Sep 17 00:00:00 2001 From: kaarthik108 Date: Thu, 17 Oct 2024 20:34:43 +1300 Subject: [PATCH 3/9] Fix loading ui --- utils/snowchat_ui.py | 41 ++++++++++++++++++++--------------------- 1 file changed, 20 insertions(+), 21 deletions(-) diff --git a/utils/snowchat_ui.py b/utils/snowchat_ui.py index 98a63370..83399cc4 100644 --- a/utils/snowchat_ui.py +++ b/utils/snowchat_ui.py @@ -78,26 +78,26 @@ def message_func(text, is_user=False, is_df=False, model="gpt"): margin_side = "margin-left" if is_user else "margin-right" message_text = html.escape(text.strip()).replace('\n', '
') - if is_user: - container_html = f""" -
-
- {message_text} + if message_text: # Check if message_text is not empty + if is_user: + container_html = f""" +
+
+ {message_text} +
+ avatar
- avatar -
- """ - else: - container_html = f""" -
- avatar -
- {message_text} + """ + else: + container_html = f""" +
+ avatar +
+ {message_text} +
-
- """ - - st.write(container_html, unsafe_allow_html=True) + """ + st.write(container_html, unsafe_allow_html=True) @@ -133,6 +133,8 @@ def on_llm_end(self, response, run_id, parent_run_id=None, **kwargs): def _get_bot_message_container(self, text): """Generate the bot's message container style for the given text.""" formatted_text = format_message(text.strip()) + if not formatted_text: # If no formatted text, show "Thinking..." + formatted_text = "Thinking..." container_content = f"""
avatar @@ -143,9 +145,6 @@ def _get_bot_message_container(self, text): """ return container_content - - - def display_dataframe(self, df): """ Display the dataframe in Streamlit UI within the chat container. From 3c6472c1ecae27f4948b251b0b41cba9d0a62fee Mon Sep 17 00:00:00 2001 From: kaarthik108 Date: Thu, 17 Oct 2024 21:14:30 +1300 Subject: [PATCH 4/9] update models --- agent.py | 9 +++++---- main.py | 4 ++-- utils/snowchat_ui.py | 2 ++ 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/agent.py b/agent.py index a13e8bd3..d3324fdc 100644 --- a/agent.py +++ b/agent.py @@ -1,4 +1,5 @@ import os +import streamlit as st from dataclasses import dataclass from typing import Annotated, Sequence, Optional @@ -36,10 +37,10 @@ def create_agent(callback_handler: BaseCallbackHandler, model_name: str): "gpt-4o": ModelConfig( model_name="gpt-4o", api_key=os.getenv("OPENAI_API_KEY") ), - "gemma2-9b": ModelConfig( - model_name="gemma2-9b-it", - api_key=os.getenv("GROQ_API_KEY"), - base_url="https://api.groq.com/openai/v1", + "Gemini Flash 1.5 8B": ModelConfig( + model_name="google/gemini-flash-1.5-8b", + api_key=st.secrets["OPENROUTER_API_KEY"], + base_url="https://openrouter.ai/api/v1", ), "claude3-haiku": ModelConfig( model_name="claude-3-haiku-20240307", api_key=os.getenv("ANTHROPIC_API_KEY") diff --git a/main.py b/main.py index fe540794..b75b3f56 100644 --- a/main.py +++ b/main.py @@ -37,9 +37,9 @@ model_options = { "gpt-4o": "GPT-4o", "llama-3.1-405b": "Llama 3.1 405B", - "gemma2-9b": "Gemma 2 9B", - "claude3-haiku": "Claude 3 Haiku", "llama-3.2-3b": "Llama 3.2 3B", + "claude3-haiku": "Claude 3 Haiku", + "Gemini Flash 1.5 8B": "Gemini Flash 1.5 8B", } model = st.radio( diff --git a/utils/snowchat_ui.py b/utils/snowchat_ui.py index 83399cc4..2fe60d0e 100644 --- a/utils/snowchat_ui.py +++ b/utils/snowchat_ui.py @@ -35,6 +35,8 @@ def get_model_url(model_name): return snow_url elif "gpt" in model_name.lower(): return openai_url + elif "gemini" in model_name.lower(): + return gemini_url return mistral_url From fde0064fd6c76506569ca09d21343e51563add01 Mon Sep 17 00:00:00 2001 From: kaarthik108 Date: Fri, 18 Oct 2024 19:10:31 +1300 Subject: [PATCH 5/9] Add Cloudflare KV caching --- agent.py | 88 +++++++++++++++++++++++------------------- graph.png | Bin 0 -> 8039 bytes main.py | 2 +- tools.py | 19 +++++++-- utils/snow_connect.py | 60 +++++++++++++++++++++++++--- 5 files changed, 120 insertions(+), 49 deletions(-) create mode 100644 graph.png diff --git a/agent.py b/agent.py index d3324fdc..fa0f2a69 100644 --- a/agent.py +++ b/agent.py @@ -13,9 +13,10 @@ from langgraph.graph.message import add_messages from langchain_core.messages import BaseMessage -from template import TEMPLATE from tools import retriever_tool - +from tools import search, sql_executor_tool +from PIL import Image +from io import BytesIO @dataclass class MessagesState: @@ -32,39 +33,43 @@ class ModelConfig: base_url: Optional[str] = None -def create_agent(callback_handler: BaseCallbackHandler, model_name: str): - model_configurations = { - "gpt-4o": ModelConfig( - model_name="gpt-4o", api_key=os.getenv("OPENAI_API_KEY") - ), - "Gemini Flash 1.5 8B": ModelConfig( - model_name="google/gemini-flash-1.5-8b", - api_key=st.secrets["OPENROUTER_API_KEY"], - base_url="https://openrouter.ai/api/v1", - ), - "claude3-haiku": ModelConfig( - model_name="claude-3-haiku-20240307", api_key=os.getenv("ANTHROPIC_API_KEY") - ), - "llama-3.2-3b": ModelConfig( - model_name="accounts/fireworks/models/llama-v3p2-3b-instruct", - api_key=os.getenv("FIREWORKS_API_KEY"), - base_url="https://api.fireworks.ai/inference/v1", - ), - "llama-3.1-405b": ModelConfig( - model_name="accounts/fireworks/models/llama-v3p1-405b-instruct", - api_key=os.getenv("FIREWORKS_API_KEY"), - base_url="https://api.fireworks.ai/inference/v1", - ), - } +model_configurations = { + "gpt-4o": ModelConfig( + model_name="gpt-4o", api_key=os.getenv("OPENAI_API_KEY") + ), + "Gemini Flash 1.5 8B": ModelConfig( + model_name="google/gemini-flash-1.5-8b", + api_key=st.secrets["OPENROUTER_API_KEY"], + base_url="https://openrouter.ai/api/v1", + ), + "claude3-haiku": ModelConfig( + model_name="claude-3-haiku-20240307", api_key=os.getenv("ANTHROPIC_API_KEY") + ), + "llama-3.2-3b": ModelConfig( + model_name="accounts/fireworks/models/llama-v3p2-3b-instruct", + api_key=os.getenv("FIREWORKS_API_KEY"), + base_url="https://api.fireworks.ai/inference/v1", + ), + "llama-3.1-405b": ModelConfig( + model_name="accounts/fireworks/models/llama-v3p1-405b-instruct", + api_key=os.getenv("FIREWORKS_API_KEY"), + base_url="https://api.fireworks.ai/inference/v1", + ), +} +sys_msg = SystemMessage( + content="""You're an AI assistant specializing in data analysis with Snowflake SQL. When providing responses, strive to exhibit friendliness and adopt a conversational tone, similar to how a friend or tutor would communicate. Do not ask the user for schema or database details. You have access to the following tools: + - Database_Schema: This tool allows you to search for database schema details when needed to generate the SQL code. + - Internet_Search: This tool allows you to search the internet for snowflake sql related information when needed to generate the SQL code. + - Snowflake_SQL_Executor: This tool allows you to execute snowflake sql queries when needed to generate the SQL code. + """ +) +tools = [retriever_tool, search, sql_executor_tool] + +def create_agent(callback_handler: BaseCallbackHandler, model_name: str) -> StateGraph: config = model_configurations.get(model_name) if not config: raise ValueError(f"Unsupported model name: {model_name}") - sys_msg = SystemMessage( - content="""You're an AI assistant specializing in data analysis with Snowflake SQL. When providing responses, strive to exhibit friendliness and adopt a conversational tone, similar to how a friend or tutor would communicate. - Call the tool "Database_Schema" to search for database schema details when needed to generate the SQL code. - """ - ) llm = ( ChatOpenAI( @@ -73,6 +78,7 @@ def create_agent(callback_handler: BaseCallbackHandler, model_name: str): callbacks=[callback_handler], streaming=True, base_url=config.base_url, + temperature=0.01, ) if config.model_name != "claude-3-haiku-20240307" else ChatAnthropic( @@ -83,21 +89,25 @@ def create_agent(callback_handler: BaseCallbackHandler, model_name: str): ) ) - tools = [retriever_tool] - llm_with_tools = llm.bind_tools(tools) - def reasoner(state: MessagesState): + def llm_agent(state: MessagesState): return {"messages": [llm_with_tools.invoke([sys_msg] + state.messages)]} - # Build the graph builder = StateGraph(MessagesState) - builder.add_node("reasoner", reasoner) + builder.add_node("llm_agent", llm_agent) builder.add_node("tools", ToolNode(tools)) - builder.add_edge(START, "reasoner") - builder.add_conditional_edges("reasoner", tools_condition) - builder.add_edge("tools", "reasoner") + builder.add_edge(START, "llm_agent") + builder.add_conditional_edges("llm_agent", tools_condition) + builder.add_edge("tools", "llm_agent") react_graph = builder.compile(checkpointer=memory) + # png_data = react_graph.get_graph(xray=True).draw_mermaid_png() + # with open("graph.png", "wb") as f: + # f.write(png_data) + + # image = Image.open(BytesIO(png_data)) + # st.image(image, caption="React Graph") + return react_graph diff --git a/graph.png b/graph.png new file mode 100644 index 0000000000000000000000000000000000000000..dcb88cd09d3a1a3d86f8a5b8262ebe6adc87849f GIT binary patch literal 8039 zcmb7I1wd5W);>r}2$C|CfFP;B(4~@!f`Zh5AkEM)^bjHvN(#~?paK%Z&?BMJ-ObP) zLwEk;z4txed;fcXtaHxV-`Q*JwZC&_pS{ ze*XX9^FLJ+o0-E*u?}0|dN@*l9tKVVZw=U@I|SRZLSTi0KG{gPjdNzLuGG_fTS_F@EJfCiupJoxqe*n6xv zWCDQrJ^DQ1t-7RsBC~SoI&cv0_b(SY8g;!xFFs%mG&5Az%-f0s>eh1l$D#0ny8GKpw!o za^=^F6+G;TPlS(;hlhWakdT0g^eQPS$yE{(GV<#bWaO0OBqS8H6qHodG&D4%*XZbJ zsp+p%(@_5kf`f}~gNILyk55cZMnXpY|CY-i03{J{3=HAoumM*nad0VdE*k(Q?CZt_ zaB+SO#lHlffDjjt2_S@F#>KsTqR*XN{NKk2i)lE!#+K4^Jdl4K9hdR*7Fa~nED$-cg^daN)!9Gw z2w;1_$0H!b8o`v<58?_AF&@b;UVjMA6-r#Z>o*1OKht>MOKqI-6Q7EPUD(9QKN>kB z#J074IR=p8Vrfz0QUY?oMQ$eB6*g8jwktDb&Na7>*VUQa{2u89MNU3v8b5>$FEyKd z0e>Bqj@z-bf+bx7XF2lzP{&Oennb_SirmSaYYZ*dzg68XAV6}Uh8{vxObLreO27~0 zsxqgjHPsu^OZh9RYDJ8F-80tsPWzkk?{rhkn5TPrrYV$SbINo&=eb;}PX9&({Jx4_KUs2=}Y z)Yc9785v=?u2I@;uXfpebZ%6h&*wqgdh@L6dR9jkl*`%QVhVgNz|H1wx#h-?B3!yL zPLoN|(Z8?KnIdPb4wJI1KT4un&Z8VfeX+WdIs><&U94zC-@md=< zHRBH6b~KYbjuEZ;Y|0=e0Fy2P_M#`^(DZK`zVHls=nh5oA9-xiTYq4eo(PpjN$;qr zZF-)dXP2l}ZmAS9xMkL8kld9@yOA(Tzp1K%W@MMCm&igk ztqWViId(MyWgk_Nc3A6rkgSMQRdcJe6}KNzSd?uql?qSp zMxn{#G?|U`Rwmr*Fh2P%X%I{mgHjilAbHxy)5vVEhc=HTW+R=4s){|(_6%NBOV7#q z=l0XGlQl`3Kt>B?1aewG9auC$c=aUYoFwlbGQKYtVn^P5`{{kj!jcRFu*{rL(QHSh zt~=Qm%x%(?kp*_eJk1U-Gv1O^8D}Tya1s{@L^#B$3(SWf;t|ze+1>;-OVcFBT1_+? zAWS-(b<$$vl-SMRvW0ghsU$R`I$y#P`c#Se9q|Mzh+WRnmh_eV(PiOC4bnJ$=GF2= z{|>vH(Fp%f$OhYkvNOy|C1Ha5M&@~kn-GK1wGTas;W9~0^GMIfpAmk#nFl>3q!pjy z#NmbfSPdD!KR(k_zU zK;kPZH1u*hxu-l?NDD)%H_rvs9LBN-ErY``;Invwv-piM6wBRlkNMP%FwB0*hl|1G z?CC+D?ZqluGX+DJ-$V2-gs+jZ-s1qa!_eY%3uHSWZawK*yrB*Qp`;_1C(4aSf(}x= z5N{wLqj}R-YnfYbx+iTP&_Y_Cz4OF-8O^{rtuCc`xQx}SuN&t9nWzZF}VY9H{NL|lukQ&l?1$#As@{b5sW>NFNfA}a!Y&nd@m7OPuFXB?#t_r%Z>b#a6QTXU+V z(zerrcU?gN$Q>V6-vzEO9=spH7KVIm4GCG(xw@vx>36NT0sF4rJ#VPaPPk<{B-kBUzXA$j$IFA{D^1y;gFc0 z-14q(?^2JpU+mP(sj92}B%k>Sh1exfV+yZ6^^Hg!KaR2@3V9V0@;|8(bP>w}?q_x3 z``bo>^O%84Aazc1v%(#-8#$W6cgSCHd>?9kJR~wUh z>)J5GX84q^=B875ecLs{4mzZ_RKDEZpC{-^c@V?$Tb78yT(Gq@IoTAFsM$uV4;TjAM{*BhlHX|JM-2bmek#0?$ zh1+_8sCA!5`XRe_L?P87*v#RN z^x-eXcdEi{Rf2bTWO!)QE7g2X`?_R;$f0|$zdw5Y zXIIN9i!!Kk0?u3%vHFVi0_JnIs8-{EI#cnMJKV@H@Wwq3%-NX^B~ zwW*j(;N8Wf4N|r`zM=M;{)CnEsg9H%Mr*i$Mq}w}X_~xDO2=s3Pf>XZ`JahmsIGaG zksXR7(pQ;+@U|1<50@$XQ0~yIJKb{vHuh3;;Er!GGVgve(cbx1PIv2;tY>2e$ZXLV zkAfKP78=gv8;DjD87z3(bi`a5Q3n1ZnVJn9^>mOr_Hma=7d>@(VRfYV#R{4# z7u0OZj?h=rJG@|!ut-wUDhBcDv)A60ye9<}vq`zXc;ZjL*oAa4d{C0VPdw~psSmA6 zih@2aFR`HPUv$NZt7-X! zZ&SR=P5jyuaD9V%2*1>$#t^4D($P7X;bLd#;q*beZ(0da{D>zLbDHeN!v5?D| zSihlKqQ9i|>%6aA;>j1$#srz_^ra^h2Y@$%eYW||`}UR1)hrkzeeY<0Hj@LzoIu*! zvhA4r_8#W1SsC!m*F9jP^I~q*l}z7W7k9pQ)0(__p=5RmW`z+bJGK4 zAp;MafY86p#w0Smmq5#+T=^Zp5U-wC&P6WXY-c+2X48p;3};Kd^~}Wgoum<~+Y6N< zom@GqL>z~Q2z&MCkguF}a1TD~;wRKxfhAM&LbYY8QpsJj$GU{-IEDpIx5qCkhL+?d}DR@A&J>@REisyGlL4~N@6Ut7{o}-n zooDYvIGU=a$U_L`)22U~`h=?+7b|lNASu**Rl+;lluVZ8D{diS(udBdt`G{i@V5VP z`b|rXw|OS!ajC+LLM0M(<{TOUT=9(T&Z;N|cTb1g7<&CK`2cfd1pepM$=|4azjO1#`^%_CzGKBNhVTh_;r{KzRS$WGrz zXSYodx+A=?5WkSmLD+zx{p-1+0NY;F;%rdtLpyoxz~y<^xOtygmKkFf(odWZh#rP1 zm!Q&HM)&S1>d~0@ipb{{>qC6=X*rXMFa-vts4Arpr485^4qV~hb-+|t6z6Lr-u7r< zjKVgDx)sb|a`u|j{g!`9Q9|DhTKap2w6Bo>!ixs9-HMdHXfH;z^ld> z6@Iu}V;rZAcoiOek9z0NWwpD~Q7vykGQhy0Q<;se*@UTfX7A-4Hm6~`&vAMwM90s{ zsg}nNbL~_c_zJ?g!xDREd8>uxq5`{YP;TUhAJmxoq$$V5nMOh7AXz0;{MY^2Ts+77 zpqoLQK|^g#xdA{Kxn7%S9%hwG8@kL0*WUQ37uLH!Po?yi&>#5d+pYb&{Sg zYV5_Cx3IKX1MK(zkRumQ@%8T|fw#QHfnrFV7l@e#GbmCKPnI)SE1 z*YbUAA*Y$VPGi__6qA4$WIM%ixeGQ{A9Uk*PigT-hDcth3ex9Sdo+}FUEvZynDj{o zgChzYr6^_wD2P_8POfpw+9*zMLi}tjyqsk9M8{LnS?+d==>73E1J&Tb^v{i7;0G)5p;OOeA=T)bywg^$j-l{YS^pw*6nTfV zB+};H?a+?5DQS?1>sbZ!_Bt6_-+AB~<)YX6mMCaU4DT*48$w`Fl@ZC!jFBwW;K1N{ z45Vw``(|v>+TpjIR@NO8&{@Jbx@nJzP_u|5X3z;yP`Ag-sJUsAeCTIy2#Ty14L@F~ zMckQI53&R8_WKwb{V1B(uL+KE((00=3&L-Rgnfw_lwG%V`-E^7u19}Jqi?HJWVLmp zclTzjf_d8UDkMMmMuw&zt8QIe;@YG3)i{$(lV;8mXc zo+h}rczY*5OMh@^-r?7|iEtS3qL}IN6u`q~BNc|AwV8bA6T@*m0>cEWG*RzF2>kQQ zU2xQWCJbvtmWc<-J#)1cy8q$H%68yzuCsz-or2-lvxE-cY_SoZnlInE1dYBb9$YC} zMYM)Buflv@=Z)Fyz-UfB8NC#Uxb`z05xwr7FR65&J;T&I0URgTwLtDRfg{EA?f8JV z3(Y?_l+)y>BbdLx$iqvsYFL4ZoQd%Km|rNT#E0fld95(gIk5gy_g_c*=umMAO;&0ne<6Y(%p1lN>=e zj4@wxj9s?Ow*1&PZh75)(v08~z(6|QQ$G-tIIWa3yC^RP52?62@Ip2fnZ5kHzbnpa z27?}cXucVne`~y%xdZnl4mxdrqS-O)h4s1Cmtx6_3*uPtpw!a=3%rg)G*X(NDm%W^ zM&YHBQ32@c{~63X zNxEk`%>yw^9EnEWrEa!1dfE-%`A%k?IE_C%4;DQ1Edu%v_P)IRVeByfsl%}O>FZC& z#1%`$2TBYpH2qZ^4aN5hi9V-W9FEZSD2#Y21tB#>1dq_o1Dl!(S2;K zgApE+@RXwq{%K$TQI>}$)nW6-wAa<{z`cMvUIwLHm!Z+%bG zl_i#(D_YN;9MseI^}3JTZ?4mg5584o;7JH7YwaR~a=aqcP){^*sTrV$qpn7p(eJ*g zu#06&P0V3}xB0P~xe6`>(NXw5bA+NxdB?@GvNvnW;EcVPZp)Ir^8@f^e_j}XGhCe- z3lAQn(yls&~ZgYrny5W~UT(J`vDE9-_9k7e2~C>*T%#;;KFvd!-#S0$+kKtWLwuh9EPiH)QUBR()? z)fr9{0d3g!i%!v_eCw{4cg84@cf9Rj25$T#dxWmt?`K2}r^w5KaN?jOiDB!blAExp z(Q>YFvmeqs_DpSkqCoiBh3n584`@E`(QagaoMW-E$kQY!!IOR=>+)N#6mL4PtH58h z^$*)cW_M@iwh7yb^)n5RpE1aWrs+eZ@WFku&`%k-QydWjuoP1$N6h5I-P}OM<~48a z@TR9fr|dOUD^Gfc5(l;_vO`eF((9u)h;4$0t}Y$D=x;vPcE4b1f#YMc>B!mGtV7SA z!8FErIo4%se(0HoD)c95_+KR7f&zZE?Cz&L6@i`F-E&!0Pjzjtao|{&md!7_Nfj%V z)mL>%=5fAlsuD&x(j-h;{FJETW?2>zjlOoeD0tDAsL(4;(+Me%`|^Cv>pmsj;?>YW zqnI0!1B<8lI;F$&i@f_PwLi2|I>pSK=)T^4bvzx-s-A%QKgA3^$?BaOEAjT0zsPPHRDd5c+R`;y5tE+>0)(pA%o&8t|Gu4e;`96Sb z>^D`ePVbHXfKL;Y)zpS z%(ry48eN`zA&hq}RieRr1~&bPd@uMql?T6T zkPN2pWRy^u8K3jB2I|JYVUX0JV0CqG*ChO~cc z_pu~&v0?iD>ZAL>VNLxZo$cS;ln&^0Rimogz%QfE56_aye!#kiyw-aMJk?mqPGYMZ z=kAk%l%N=ym<7yE;=E-`c-DMF{Un5#3pUsgQv?zz=Tjm9y()a!^t2%IO=>%@!bpAO zr*+Bn{KAu2^T`i+yxD`jwAoEM+8LB|WcVZry%4{rBeX&Z()LP@5(?#CmDzvn*-xvr zL1~{4Yf*Vw^jT}wUfr|omF7D1HVTye4I$6$G4yCym-cRr_7QJyy24ia#I`T>Z%6}G z_`^Z9YhQmFq2xnV=jYydr1=|aeWofkb?TP5Bh}+kK$d-1Ti=rDRnya=iDQcb4e`l| g(*ss}G0l)rEv-R=$=q)?l$0-jCvv}$+~w$h0rjWn;{X5v literal 0 HcmV?d00001 diff --git a/main.py b/main.py index b75b3f56..ddbe5d19 100644 --- a/main.py +++ b/main.py @@ -204,7 +204,7 @@ def execute_sql(query, conn, retries=2): messages = [HumanMessage(content=user_input_content)] state = MessagesState(messages=messages) - result = react_graph.invoke(state, config=config) + result = react_graph.invoke(state, config=config, debug=True) if result["messages"]: assistant_message = callback_handler.final_message diff --git a/tools.py b/tools.py index 5b5a4504..2be57ba5 100644 --- a/tools.py +++ b/tools.py @@ -1,15 +1,15 @@ import streamlit as st -from langchain.prompts.prompt import PromptTemplate from supabase.client import Client, create_client -from langchain.tools.retriever import create_retriever_tool from langchain_openai import OpenAIEmbeddings from langchain_community.vectorstores import SupabaseVectorStore +from langchain.tools.retriever import create_retriever_tool +from langchain_community.tools import DuckDuckGoSearchRun +from utils.snow_connect import SnowflakeConnection supabase_url = st.secrets["SUPABASE_URL"] supabase_key = st.secrets["SUPABASE_SERVICE_KEY"] supabase: Client = create_client(supabase_url, supabase_key) - embeddings = OpenAIEmbeddings( openai_api_key=st.secrets["OPENAI_API_KEY"], model="text-embedding-ada-002" ) @@ -20,9 +20,20 @@ query_name="v_match_documents", ) - retriever_tool = create_retriever_tool( vectorstore.as_retriever(), name="Database_Schema", description="Search for database schema details", ) + +search = DuckDuckGoSearchRun() + +def sql_executor_tool(query: str, use_cache: bool = True) -> str: + """ + Execute snowflake sql queries with optional caching. + """ + conn = SnowflakeConnection() + return conn.execute_query(query, use_cache) + +if __name__ == "__main__": + print(sql_executor_tool("select * from STREAM_HACKATHON.STREAMLIT.CUSTOMER_DETAILS")) diff --git a/utils/snow_connect.py b/utils/snow_connect.py index d0b396a6..f525adc3 100644 --- a/utils/snow_connect.py +++ b/utils/snow_connect.py @@ -1,12 +1,13 @@ from typing import Any, Dict - +import json +import requests import streamlit as st from snowflake.snowpark.session import Session class SnowflakeConnection: """ - This class is used to establish a connection to Snowflake. + This class is used to establish a connection to Snowflake and execute queries with optional caching. Attributes ---------- @@ -19,16 +20,24 @@ class SnowflakeConnection: ------- get_session() Establishes and returns the Snowflake connection session. - + execute_query(query: str, use_cache: bool = True) + Executes a Snowflake SQL query with optional caching. """ def __init__(self): self.connection_parameters = self._get_connection_parameters_from_env() self.session = None + self.cloudflare_account_id = st.secrets["CLOUDFLARE_ACCOUNT_ID"] + self.cloudflare_namespace_id = st.secrets["CLOUDFLARE_NAMESPACE_ID"] + self.cloudflare_api_token = st.secrets["CLOUDFLARE_API_TOKEN"] + self.headers = { + "Authorization": f"Bearer {self.cloudflare_api_token}", + "Content-Type": "application/json" + } @staticmethod def _get_connection_parameters_from_env() -> Dict[str, Any]: - connection_parameters = { + return { "account": st.secrets["ACCOUNT"], "user": st.secrets["USER_NAME"], "password": st.secrets["PASSWORD"], @@ -37,7 +46,6 @@ def _get_connection_parameters_from_env() -> Dict[str, Any]: "schema": st.secrets["SCHEMA"], "role": st.secrets["ROLE"], } - return connection_parameters def get_session(self): """ @@ -49,3 +57,45 @@ def get_session(self): self.session = Session.builder.configs(self.connection_parameters).create() self.session.sql_simplifier_enabled = True return self.session + + def _construct_kv_url(self, key: str) -> str: + return f"https://api.cloudflare.com/client/v4/accounts/{self.cloudflare_account_id}/storage/kv/namespaces/{self.cloudflare_namespace_id}/values/{key}" + + def get_from_cache(self, key: str) -> str: + url = self._construct_kv_url(key) + try: + response = requests.get(url, headers=self.headers) + response.raise_for_status() + print("\n\n\nCache hit\n\n\n") + return response.text + except requests.exceptions.RequestException as e: + print(f"Cache miss or error: {e}") + return None + + def set_to_cache(self, key: str, value: str) -> None: + url = self._construct_kv_url(key) + serialized_value = json.dumps(value) + try: + response = requests.put(url, headers=self.headers, data=serialized_value) + response.raise_for_status() + print("Cache set successfully") + except requests.exceptions.RequestException as e: + print(f"Failed to set cache: {e}") + + def execute_query(self, query: str, use_cache: bool = True) -> str: + """ + Execute a Snowflake SQL query with optional caching. + """ + if use_cache: + cached_response = self.get_from_cache(query) + if cached_response: + return json.loads(cached_response) + + session = self.get_session() + result = session.sql(query).collect() + result_list = [row.as_dict() for row in result] + + if use_cache: + self.set_to_cache(query, result_list) + + return result_list From cc1bbd0d7e777169ba67d8ab20c4f5adb6a81531 Mon Sep 17 00:00:00 2001 From: kaarthik108 Date: Fri, 18 Oct 2024 19:18:07 +1300 Subject: [PATCH 6/9] update reqs --- requirements.txt | 28 ++++++++++++++-------------- tools.py | 4 ++-- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/requirements.txt b/requirements.txt index 98406b92..3398eaa4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,15 +1,15 @@ -langchain==0.2.12 -pandas==1.5.0 -pydantic==1.10.8 +langchain==0.3.3 +langchain_anthropic==0.2.3 +langchain_community==0.3.2 +langchain_core==0.3.12 +langchain_openai==0.2.2 +langchain-google-genai==2.0.1 +langgraph==0.2.38 +Pillow==11.0.0 +pydantic==2.9.2 +Requests==2.32.3 +snowflake_connector_python==3.1.0 snowflake_snowpark_python==1.5.0 -snowflake-snowpark-python[pandas] -streamlit==1.31.0 -supabase==2.4.1 -unstructured -tiktoken -openai -black -langchain_openai -langchain-community -langchain-core -langchain-anthropic \ No newline at end of file +streamlit==1.33.0 +websocket_client==1.7.0 +duckduckgo_search==6.3.0 \ No newline at end of file diff --git a/tools.py b/tools.py index 2be57ba5..fa89599e 100644 --- a/tools.py +++ b/tools.py @@ -35,5 +35,5 @@ def sql_executor_tool(query: str, use_cache: bool = True) -> str: conn = SnowflakeConnection() return conn.execute_query(query, use_cache) -if __name__ == "__main__": - print(sql_executor_tool("select * from STREAM_HACKATHON.STREAMLIT.CUSTOMER_DETAILS")) +# if __name__ == "__main__": +# print(sql_executor_tool("select * from STREAM_HACKATHON.STREAMLIT.CUSTOMER_DETAILS")) From 5b30c10fb57192e555b86969a0f3a2dba11b2c59 Mon Sep 17 00:00:00 2001 From: kaarthik108 Date: Fri, 18 Oct 2024 19:28:17 +1300 Subject: [PATCH 7/9] update prompt --- agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/agent.py b/agent.py index fa0f2a69..2dc8ded9 100644 --- a/agent.py +++ b/agent.py @@ -60,7 +60,7 @@ class ModelConfig: content="""You're an AI assistant specializing in data analysis with Snowflake SQL. When providing responses, strive to exhibit friendliness and adopt a conversational tone, similar to how a friend or tutor would communicate. Do not ask the user for schema or database details. You have access to the following tools: - Database_Schema: This tool allows you to search for database schema details when needed to generate the SQL code. - Internet_Search: This tool allows you to search the internet for snowflake sql related information when needed to generate the SQL code. - - Snowflake_SQL_Executor: This tool allows you to execute snowflake sql queries when needed to generate the SQL code. + - Snowflake_SQL_Executor: This tool allows you to execute snowflake sql queries when needed to generate the SQL code. You only have read access to the database, do not modify the database in any way. """ ) tools = [retriever_tool, search, sql_executor_tool] From c5b26abffedc69e9c314c19fb94620fe244f2f85 Mon Sep 17 00:00:00 2001 From: kaarthik108 Date: Fri, 18 Oct 2024 19:28:53 +1300 Subject: [PATCH 8/9] remove pillow --- requirements.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 3398eaa4..3a356826 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,7 +5,6 @@ langchain_core==0.3.12 langchain_openai==0.2.2 langchain-google-genai==2.0.1 langgraph==0.2.38 -Pillow==11.0.0 pydantic==2.9.2 Requests==2.32.3 snowflake_connector_python==3.1.0 From d62fbb2bed99e58bd9a4e73a8cd94029397c724e Mon Sep 17 00:00:00 2001 From: kaarthik108 Date: Fri, 18 Oct 2024 19:38:34 +1300 Subject: [PATCH 9/9] add error messages --- .github/workflows/lint.yml | 42 -------------------------------------- README.md | 21 +++++++++---------- agent.py | 10 +++++---- utils/snowchat_ui.py | 1 - 4 files changed, 16 insertions(+), 58 deletions(-) delete mode 100644 .github/workflows/lint.yml diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml deleted file mode 100644 index 6c7256e3..00000000 --- a/.github/workflows/lint.yml +++ /dev/null @@ -1,42 +0,0 @@ -name: Lint - -on: - push: - branches: - - main - pull_request: - branches: - - main - -jobs: - lint: - name: Lint and Format Code - runs-on: ubuntu-latest - - steps: - - name: Check out repository - uses: actions/checkout@v3 - - - name: Set up Python - uses: actions/setup-python@v4 - with: - python-version: "3.9" - - - name: Cache pip dependencies - uses: actions/cache@v3 - with: - path: ~/.cache/pip - key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }} - restore-keys: | - ${{ runner.os }}-pip- - - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install -r requirements.txt - pip install black ruff mypy codespell - - - name: Run Formatting and Linting - run: | - make format - make lint diff --git a/README.md b/README.md index afe05936..7590ad5e 100644 --- a/README.md +++ b/README.md @@ -15,9 +15,11 @@ ## Supported LLM's -- GPT-3.5-turbo-0125 -- CodeLlama-70B -- Mistral Medium +- GPT-4o +- Gemini Flash 1.5 8B +- Claude 3 Haiku +- Llama 3.2 3B +- Llama 3.1 405B # @@ -27,11 +29,12 @@ https://github.com/kaarthik108/snowChat/assets/53030784/24105e23-69d3-4676-b6d6- ## 🌟 Features -- **Conversational AI**: Harnesses ChatGPT to translate natural language into precise SQL queries. +- **Conversational AI**: Use ChatGPT and other models to translate natural language into precise SQL queries. - **Conversational Memory**: Retains context for interactive, dynamic responses. - **Snowflake Integration**: Offers seamless, real-time data insights straight from your Snowflake database. - **Self-healing SQL**: Proactively suggests solutions for SQL errors, streamlining data access. - **Interactive User Interface**: Transforms data querying into an engaging conversation, complete with a chat reset option. +- **Agent-based Architecture**: Utilizes an agent to manage interactions and tool usage. ## 🛠️ Installation @@ -42,7 +45,9 @@ https://github.com/kaarthik108/snowChat/assets/53030784/24105e23-69d3-4676-b6d6- cd snowchat pip install -r requirements.txt -3. Set up your `OPENAI_API_KEY`, `ACCOUNT`, `USER_NAME`, `PASSWORD`, `ROLE`, `DATABASE`, `SCHEMA`, `WAREHOUSE`, `SUPABASE_URL` , `SUPABASE_SERVICE_KEY` and `REPLICATE_API_TOKEN` in project directory `secrets.toml`. +3. Set up your `OPENAI_API_KEY`, `ACCOUNT`, `USER_NAME`, `PASSWORD`, `ROLE`, `DATABASE`, `SCHEMA`, `WAREHOUSE`, `SUPABASE_URL` , `SUPABASE_SERVICE_KEY`, `SUPABASE_STORAGE_URL`,`CLOUDFLARE_ACCOUNT_ID`, `CLOUDFLARE_NAMESPACE_ID`, + `CLOUDFLARE_API_TOKEN` in project directory `secrets.toml`. + Cloudflare is used here for caching Snowflake responses in KV. 4. Make you're schemas and store them in docs folder that matches you're database. @@ -53,12 +58,6 @@ https://github.com/kaarthik108/snowChat/assets/53030784/24105e23-69d3-4676-b6d6- 7. Run the Streamlit app to start chatting: streamlit run main.py -## 🚀 Additional Enhancements - -1. **Platform Integration**: Connect snowChat with popular communication platforms like Slack or Discord for seamless interaction. -2. **Voice Integration**: Implement voice recognition and text-to-speech functionality to make the chatbot more interactive and user-friendly. -3. **Advanced Analytics**: Integrate with popular data visualization libraries like Plotly or Matplotlib to generate interactive visualizations based on the user's queries (AutoGPT). - ## Star History [![Star History Chart](https://api.star-history.com/svg?repos=kaarthik108/snowChat&type=Date)] diff --git a/agent.py b/agent.py index 2dc8ded9..95a8c107 100644 --- a/agent.py +++ b/agent.py @@ -35,7 +35,7 @@ class ModelConfig: model_configurations = { "gpt-4o": ModelConfig( - model_name="gpt-4o", api_key=os.getenv("OPENAI_API_KEY") + model_name="gpt-4o", api_key=st.secrets["OPENAI_API_KEY"] ), "Gemini Flash 1.5 8B": ModelConfig( model_name="google/gemini-flash-1.5-8b", @@ -43,16 +43,16 @@ class ModelConfig: base_url="https://openrouter.ai/api/v1", ), "claude3-haiku": ModelConfig( - model_name="claude-3-haiku-20240307", api_key=os.getenv("ANTHROPIC_API_KEY") + model_name="claude-3-haiku-20240307", api_key=st.secrets["ANTHROPIC_API_KEY"] ), "llama-3.2-3b": ModelConfig( model_name="accounts/fireworks/models/llama-v3p2-3b-instruct", - api_key=os.getenv("FIREWORKS_API_KEY"), + api_key=st.secrets["FIREWORKS_API_KEY"], base_url="https://api.fireworks.ai/inference/v1", ), "llama-3.1-405b": ModelConfig( model_name="accounts/fireworks/models/llama-v3p1-405b-instruct", - api_key=os.getenv("FIREWORKS_API_KEY"), + api_key=st.secrets["FIREWORKS_API_KEY"], base_url="https://api.fireworks.ai/inference/v1", ), } @@ -70,6 +70,8 @@ def create_agent(callback_handler: BaseCallbackHandler, model_name: str) -> Stat if not config: raise ValueError(f"Unsupported model name: {model_name}") + if not config.api_key: + raise ValueError(f"API key for model '{model_name}' is not set. Please check your environment variables or secrets configuration.") llm = ( ChatOpenAI( diff --git a/utils/snowchat_ui.py b/utils/snowchat_ui.py index 2fe60d0e..05db4b1e 100644 --- a/utils/snowchat_ui.py +++ b/utils/snowchat_ui.py @@ -1,6 +1,5 @@ import html import re -import textwrap import streamlit as st from langchain.callbacks.base import BaseCallbackHandler