diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml
deleted file mode 100644
index 97e02502..00000000
--- a/.github/workflows/lint.yml
+++ /dev/null
@@ -1,29 +0,0 @@
-name: Lint
-
-on:
-  push:
-    branches:
-      - main
-      - prod
-  pull_request:
-    branches:
-      - main
-      - prod
-
-jobs:
-  lint:
-    runs-on: ubuntu-latest
-    steps:
-      - name: Check out repository
-        uses: actions/checkout@v3
-
-      - name: Set up Python
-        uses: actions/setup-python@v2
-        with:
-          python-version: 3.9
-
-      - name: Install dependencies
-        run: pip install black
-
-      - name: Lint with black
-        run: black --check .
diff --git a/Makefile b/Makefile
new file mode 100644
index 00000000..a768e53a
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,56 @@
+.PHONY: all format lint lint_diff format_diff lint_package lint_tests spell_check spell_fix help lint-fix
+
+# Define a variable for Python and notebook files.
+PYTHON_FILES=src/
+MYPY_CACHE=.mypy_cache
+
+######################
+# LINTING AND FORMATTING
+######################
+
+lint format: PYTHON_FILES=.
+lint_diff format_diff: PYTHON_FILES=$(shell git diff --name-only --diff-filter=d main | grep -E '\.py$$|\.ipynb$$')
+lint_package: PYTHON_FILES=src
+lint_tests: PYTHON_FILES=tests
+lint_tests: MYPY_CACHE=.mypy_cache_test
+
+lint lint_diff lint_package lint_tests:
+	python -m ruff check .
+	[ "$(PYTHON_FILES)" = "" ] || python -m ruff format $(PYTHON_FILES) --diff
+	[ "$(PYTHON_FILES)" = "" ] || python -m ruff check --select I,F401 --fix $(PYTHON_FILES)
+	[ "$(PYTHON_FILES)" = "" ] || ( mkdir -p $(MYPY_CACHE) && python -m mypy --strict $(PYTHON_FILES) --cache-dir $(MYPY_CACHE) )
+
+format format_diff:
+	ruff format $(PYTHON_FILES)
+	ruff check --fix $(PYTHON_FILES)
+
+spell_check:
+	codespell --toml pyproject.toml
+
+spell_fix:
+	codespell --toml pyproject.toml -w
+
+######################
+# RUN ALL
+######################
+
+all: format lint spell_check
+
+######################
+# HELP
+######################
+
+help:
+	@echo '----'
+	@echo 'format      - run code formatters'
+	@echo 'lint        - run linters'
+	@echo 'spell_check - run spell check'
+	@echo 'all         - run all tasks'
+	@echo 'lint-fix    - run lint and fix issues'
+
+######################
+# LINT-FIX TARGET
+######################
+
+lint-fix: format lint
+	@echo "Linting and fixing completed successfully."
diff --git a/README.md b/README.md
index afe05936..7590ad5e 100644
--- a/README.md
+++ b/README.md
@@ -15,9 +15,11 @@

 ## Supported LLMs

-- GPT-3.5-turbo-0125
-- CodeLlama-70B
-- Mistral Medium
+- GPT-4o
+- Gemini Flash 1.5 8B
+- Claude 3 Haiku
+- Llama 3.2 3B
+- Llama 3.1 405B

 #

@@ -27,11 +29,12 @@ https://github.com/kaarthik108/snowChat/assets/53030784/24105e23-69d3-4676-b6d6-

 ## 🌟 Features

-- **Conversational AI**: Harnesses ChatGPT to translate natural language into precise SQL queries.
+- **Conversational AI**: Uses ChatGPT and other LLMs to translate natural language into precise SQL queries.
 - **Conversational Memory**: Retains context for interactive, dynamic responses.
 - **Snowflake Integration**: Offers seamless, real-time data insights straight from your Snowflake database.
 - **Self-healing SQL**: Proactively suggests solutions for SQL errors, streamlining data access.
 - **Interactive User Interface**: Transforms data querying into an engaging conversation, complete with a chat reset option.
+- **Agent-based Architecture**: Uses a LangGraph agent to manage interactions and tool usage.

 ## 🛠️ Installation

@@ -42,7 +45,9 @@ https://github.com/kaarthik108/snowChat/assets/53030784/24105e23-69d3-4676-b6d6-
    cd snowchat
    pip install -r requirements.txt

-3. Set up your `OPENAI_API_KEY`, `ACCOUNT`, `USER_NAME`, `PASSWORD`, `ROLE`, `DATABASE`, `SCHEMA`, `WAREHOUSE`, `SUPABASE_URL` , `SUPABASE_SERVICE_KEY` and `REPLICATE_API_TOKEN` in project directory `secrets.toml`.
+3. Set up your `OPENAI_API_KEY`, `ACCOUNT`, `USER_NAME`, `PASSWORD`, `ROLE`, `DATABASE`, `SCHEMA`, `WAREHOUSE`, `SUPABASE_URL`, `SUPABASE_SERVICE_KEY`, `SUPABASE_STORAGE_URL`, `CLOUDFLARE_ACCOUNT_ID`, `CLOUDFLARE_NAMESPACE_ID`, and
+   `CLOUDFLARE_API_TOKEN` in the project's `secrets.toml`.
+   Cloudflare is used here to cache Snowflake responses in KV.

 4. Make your schemas and store them in the docs folder so they match your database.

@@ -53,12 +58,6 @@ https://github.com/kaarthik108/snowChat/assets/53030784/24105e23-69d3-4676-b6d6-
 7. Run the Streamlit app to start chatting:

    streamlit run main.py

-## 🚀 Additional Enhancements
-
-1. **Platform Integration**: Connect snowChat with popular communication platforms like Slack or Discord for seamless interaction.
-2. **Voice Integration**: Implement voice recognition and text-to-speech functionality to make the chatbot more interactive and user-friendly.
-3. **Advanced Analytics**: Integrate with popular data visualization libraries like Plotly or Matplotlib to generate interactive visualizations based on the user's queries (AutoGPT).
-
 ## Star History

 [![Star History Chart](https://api.star-history.com/svg?repos=kaarthik108/snowChat&type=Date)]
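The installation steps above list quite a few secrets. A minimal startup check — a sketch, assuming Streamlit's standard `.streamlit/secrets.toml` loading and using only the key names from installation step 3 — fails fast on anything missing before the app boots:

```python
import streamlit as st

# Key names taken from installation step 3 above.
REQUIRED_SECRETS = [
    "OPENAI_API_KEY", "ACCOUNT", "USER_NAME", "PASSWORD", "ROLE",
    "DATABASE", "SCHEMA", "WAREHOUSE", "SUPABASE_URL", "SUPABASE_SERVICE_KEY",
    "SUPABASE_STORAGE_URL", "CLOUDFLARE_ACCOUNT_ID",
    "CLOUDFLARE_NAMESPACE_ID", "CLOUDFLARE_API_TOKEN",
]

missing = [key for key in REQUIRED_SECRETS if key not in st.secrets]
if missing:
    raise RuntimeError(f"Missing entries in secrets.toml: {missing}")
```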
diff --git a/agent.py b/agent.py
new file mode 100644
index 00000000..95a8c107
--- /dev/null
+++ b/agent.py
@@ -0,0 +1,115 @@
+import os
+import streamlit as st
+from dataclasses import dataclass
+from typing import Annotated, Sequence, Optional
+
+from langchain.callbacks.base import BaseCallbackHandler
+from langchain_anthropic import ChatAnthropic
+from langchain_core.messages import SystemMessage
+from langchain_openai import ChatOpenAI
+from langgraph.checkpoint.memory import MemorySaver
+from langgraph.graph import START, StateGraph
+from langgraph.prebuilt import ToolNode, tools_condition
+from langgraph.graph.message import add_messages
+from langchain_core.messages import BaseMessage
+
+from tools import retriever_tool
+from tools import search, sql_executor_tool
+from PIL import Image
+from io import BytesIO
+
+
+@dataclass
+class MessagesState:
+    messages: Annotated[Sequence[BaseMessage], add_messages]
+
+
+memory = MemorySaver()
+
+
+@dataclass
+class ModelConfig:
+    model_name: str
+    api_key: str
+    base_url: Optional[str] = None
+
+
+model_configurations = {
+    "gpt-4o": ModelConfig(
+        model_name="gpt-4o", api_key=st.secrets["OPENAI_API_KEY"]
+    ),
+    "Gemini Flash 1.5 8B": ModelConfig(
+        model_name="google/gemini-flash-1.5-8b",
+        api_key=st.secrets["OPENROUTER_API_KEY"],
+        base_url="https://openrouter.ai/api/v1",
+    ),
+    "claude3-haiku": ModelConfig(
+        model_name="claude-3-haiku-20240307", api_key=st.secrets["ANTHROPIC_API_KEY"]
+    ),
+    "llama-3.2-3b": ModelConfig(
+        model_name="accounts/fireworks/models/llama-v3p2-3b-instruct",
+        api_key=st.secrets["FIREWORKS_API_KEY"],
+        base_url="https://api.fireworks.ai/inference/v1",
+    ),
+    "llama-3.1-405b": ModelConfig(
+        model_name="accounts/fireworks/models/llama-v3p1-405b-instruct",
+        api_key=st.secrets["FIREWORKS_API_KEY"],
+        base_url="https://api.fireworks.ai/inference/v1",
+    ),
+}
+sys_msg = SystemMessage(
+    content="""You're an AI assistant specializing in data analysis with Snowflake SQL. When providing responses, strive to exhibit friendliness and adopt a conversational tone, similar to how a friend or tutor would communicate. Do not ask the user for schema or database details. You have access to the following tools:
+    - Database_Schema: This tool allows you to search for database schema details when needed to generate the SQL code.
+    - Internet_Search: This tool allows you to search the internet for Snowflake SQL-related information when needed to generate the SQL code.
+    - Snowflake_SQL_Executor: This tool allows you to execute Snowflake SQL queries when needed. You only have read access to the database; do not modify the database in any way.
+    """
+)
+tools = [retriever_tool, search, sql_executor_tool]
+
+
+def create_agent(callback_handler: BaseCallbackHandler, model_name: str) -> StateGraph:
+    config = model_configurations.get(model_name)
+    if not config:
+        raise ValueError(f"Unsupported model name: {model_name}")
+
+    if not config.api_key:
+        raise ValueError(
+            f"API key for model '{model_name}' is not set. Please check your environment variables or secrets configuration."
+        )
+
+    llm = (
+        ChatOpenAI(
+            model=config.model_name,
+            api_key=config.api_key,
+            callbacks=[callback_handler],
+            streaming=True,
+            base_url=config.base_url,
+            temperature=0.01,
+        )
+        if config.model_name != "claude-3-haiku-20240307"
+        else ChatAnthropic(
+            model=config.model_name,
+            api_key=config.api_key,
+            callbacks=[callback_handler],
+            streaming=True,
+        )
+    )
+
+    llm_with_tools = llm.bind_tools(tools)
+
+    def llm_agent(state: MessagesState):
+        return {"messages": [llm_with_tools.invoke([sys_msg] + state.messages)]}
+
+    builder = StateGraph(MessagesState)
+    builder.add_node("llm_agent", llm_agent)
+    builder.add_node("tools", ToolNode(tools))
+    builder.add_edge(START, "llm_agent")
+    builder.add_conditional_edges("llm_agent", tools_condition)
+    builder.add_edge("tools", "llm_agent")
+
+    react_graph = builder.compile(checkpointer=memory)
+
+    # png_data = react_graph.get_graph(xray=True).draw_mermaid_png()
+    # with open("graph.png", "wb") as f:
+    #     f.write(png_data)
+
+    # image = Image.open(BytesIO(png_data))
+    # st.image(image, caption="React Graph")
+
+    return react_graph
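For orientation, here is how the compiled graph is meant to be driven — a minimal sketch mirroring the `main.py` changes further down; the model name, thread id, and prompt are placeholder values:

```python
from langchain_core.messages import HumanMessage

from agent import MessagesState, create_agent
from utils.snowchat_ui import StreamlitUICallbackHandler

# thread_id keys the MemorySaver checkpoint: reusing an id resumes that
# conversation's memory, while a fresh id starts a clean thread.
config = {"configurable": {"thread_id": "demo-thread"}}

handler = StreamlitUICallbackHandler("gpt-4o")
graph = create_agent(handler, "gpt-4o")

state = MessagesState(messages=[HumanMessage(content="Total sales by month?")])
result = graph.invoke(state, config=config)
print(result["messages"][-1].content)
```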
diff --git a/chain.py b/chain.py
index a16b8627..a0c25952 100644
--- a/chain.py
+++ b/chain.py
@@ -1,155 +1,154 @@
-from typing import Any, Callable, Dict, Optional
-
-import streamlit as st
-from langchain_community.chat_models import ChatOpenAI
-from langchain.embeddings.openai import OpenAIEmbeddings
-from langchain.llms import OpenAI
-from langchain.vectorstores import SupabaseVectorStore
-from pydantic import BaseModel, validator
-from supabase.client import Client, create_client
-
-from template import CONDENSE_QUESTION_PROMPT, QA_PROMPT
-
-from operator import itemgetter
-
-from langchain.prompts.prompt import PromptTemplate
-from langchain.schema import format_document
-from langchain_core.messages import get_buffer_string
-from langchain_core.output_parsers import StrOutputParser
-from langchain_core.runnables import RunnableParallel, RunnablePassthrough
-from langchain_openai import ChatOpenAI, OpenAIEmbeddings
-from langchain_anthropic import ChatAnthropic
-
-DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")
-
-supabase_url = st.secrets["SUPABASE_URL"]
-supabase_key = st.secrets["SUPABASE_SERVICE_KEY"]
-supabase: Client = create_client(supabase_url, supabase_key)
-
-
-class ModelConfig(BaseModel):
-    model_type: str
-    secrets: Dict[str, Any]
-    callback_handler: Optional[Callable] = None
-
-
-class ModelWrapper:
-    def __init__(self, config: ModelConfig):
-        self.model_type = config.model_type
-        self.secrets = config.secrets
-        self.callback_handler = config.callback_handler
-        self.llm = self._setup_llm()
-
-    def _setup_llm(self):
-        model_config = {
-            "gpt-4o-mini": {
-                "model_name": "gpt-4o-mini",
-                "api_key": self.secrets["OPENAI_API_KEY"],
-            },
-            "gemma2-9b": {
-                "model_name": "gemma2-9b-it",
-                "api_key": self.secrets["GROQ_API_KEY"],
-                "base_url": "https://api.groq.com/openai/v1",
-            },
-            "claude3-haiku": {
-                "model_name": "claude-3-haiku-20240307",
-                "api_key": self.secrets["ANTHROPIC_API_KEY"],
-            },
-            "mixtral-8x22b": {
-                "model_name": "accounts/fireworks/models/mixtral-8x22b-instruct",
-                "api_key": self.secrets["FIREWORKS_API_KEY"],
-                "base_url": "https://api.fireworks.ai/inference/v1",
-            },
-            "llama-3.1-405b": {
-                "model_name": "accounts/fireworks/models/llama-v3p1-405b-instruct",
-                "api_key": self.secrets["FIREWORKS_API_KEY"],
-                "base_url": "https://api.fireworks.ai/inference/v1",
-            },
-        }
-
-        config = model_config[self.model_type]
-
-        return (
-            ChatOpenAI(
-                model_name=config["model_name"],
-                temperature=0.1,
-                api_key=config["api_key"],
-                max_tokens=700,
-                callbacks=[self.callback_handler],
-                streaming=True,
-                base_url=config["base_url"]
-                if config["model_name"] != "gpt-4o-mini"
-                else None,
-                default_headers={
-                    "HTTP-Referer": "https://snowchat.streamlit.app/",
-                    "X-Title": "Snowchat",
-                },
-            )
-            if config["model_name"] != "claude-3-haiku-20240307"
-            else (
-                ChatAnthropic(
-                    model=config["model_name"],
-                    temperature=0.1,
-                    max_tokens=700,
-                    timeout=None,
-                    max_retries=2,
-                    callbacks=[self.callback_handler],
-                    streaming=True,
-                )
-            )
-        )
-
-    def get_chain(self, vectorstore):
-        def _combine_documents(
-            docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, document_separator="\n\n"
-        ):
-            doc_strings = [format_document(doc, document_prompt) for doc in docs]
-            return document_separator.join(doc_strings)
-
-        _inputs = RunnableParallel(
-            standalone_question=RunnablePassthrough.assign(
-                chat_history=lambda x: get_buffer_string(x["chat_history"])
-            )
-            | CONDENSE_QUESTION_PROMPT
-            | OpenAI()
-            | StrOutputParser(),
-        )
-        _context = {
-            "context": itemgetter("standalone_question")
-            | vectorstore.as_retriever()
-            | _combine_documents,
-            "question": lambda x: x["standalone_question"],
-        }
-        conversational_qa_chain = _inputs | _context | QA_PROMPT | self.llm
-
-        return conversational_qa_chain
-
-
-def load_chain(model_name="qwen", callback_handler=None):
-    embeddings = OpenAIEmbeddings(
-        openai_api_key=st.secrets["OPENAI_API_KEY"], model="text-embedding-ada-002"
-    )
-    vectorstore = SupabaseVectorStore(
-        embedding=embeddings,
-        client=supabase,
-        table_name="documents",
-        query_name="v_match_documents",
-    )
-
-    model_type_mapping = {
-        "gpt-4o-mini": "gpt-4o-mini",
-        "gemma2-9b": "gemma2-9b",
-        "claude3-haiku": "claude3-haiku",
-        "mixtral-8x22b": "mixtral-8x22b",
-        "llama-3.1-405b": "llama-3.1-405b",
-    }
-
-    model_type = model_type_mapping.get(model_name.lower())
-    if model_type is None:
-        raise ValueError(f"Unsupported model name: {model_name}")
-
-    config = ModelConfig(
-        model_type=model_type, secrets=st.secrets, callback_handler=callback_handler
-    )
-    model = ModelWrapper(config)
-    return model.get_chain(vectorstore)
+# from dataclasses import dataclass, field
+# from operator import itemgetter
+# from typing import Any, Callable, Dict, Optional
+
+# import streamlit as st
+# from langchain.embeddings.openai import OpenAIEmbeddings
+# from langchain.llms import OpenAI
+# from langchain.prompts.prompt import PromptTemplate
+# from langchain.schema import format_document
+# from langchain.vectorstores import SupabaseVectorStore
+# from langchain_anthropic import ChatAnthropic
+# from langchain_community.chat_models import ChatOpenAI
+# from langchain_core.messages import get_buffer_string
+# from langchain_core.output_parsers import StrOutputParser
+# from langchain_core.runnables import RunnableParallel, RunnablePassthrough
+# from langchain_openai import ChatOpenAI, OpenAIEmbeddings
+
+# from supabase.client import Client, create_client
+# from template import CONDENSE_QUESTION_PROMPT, QA_PROMPT
+
+# DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")
+
+# supabase_url = st.secrets["SUPABASE_URL"]
+# supabase_key = st.secrets["SUPABASE_SERVICE_KEY"]
+# supabase: Client = create_client(supabase_url, supabase_key)
+
+
+# @dataclass
+# class ModelConfig:
+#     model_type: str
+#     secrets: Dict[str, Any]
+#     callback_handler: Optional[Callable] = field(default=None)
+
+
+# class ModelWrapper:
+#     def __init__(self, config: ModelConfig):
+#         self.model_type = config.model_type
+#         self.secrets = config.secrets
+#         self.callback_handler = config.callback_handler
+#         self.llm = self._setup_llm()
+
+#     def _setup_llm(self):
+#         model_config = {
+#             "gpt-4o-mini": {
+#                 "model_name": "gpt-4o-mini",
+#                 "api_key": self.secrets["OPENAI_API_KEY"],
+#             },
+#             "gemma2-9b": {
+#                 "model_name": "gemma2-9b-it",
+#                 "api_key": self.secrets["GROQ_API_KEY"],
+#                 "base_url": "https://api.groq.com/openai/v1",
+#             },
+#             "claude3-haiku": {
+#                 "model_name": "claude-3-haiku-20240307",
+#                 "api_key": self.secrets["ANTHROPIC_API_KEY"],
+#             },
+#             "mixtral-8x22b": {
+#                 "model_name": "accounts/fireworks/models/mixtral-8x22b-instruct",
+#                 "api_key": self.secrets["FIREWORKS_API_KEY"],
+#                 "base_url": "https://api.fireworks.ai/inference/v1",
+#             },
+#             "llama-3.1-405b": {
+#                 "model_name": "accounts/fireworks/models/llama-v3p1-405b-instruct",
+#                 "api_key": self.secrets["FIREWORKS_API_KEY"],
+#                 "base_url": "https://api.fireworks.ai/inference/v1",
+#             },
+#         }
+
+#         config = model_config[self.model_type]
+
+#         return (
+#             ChatOpenAI(
+#                 model_name=config["model_name"],
+#                 temperature=0.1,
+#                 api_key=config["api_key"],
+#                 max_tokens=700,
+#                 callbacks=[self.callback_handler],
+#                 streaming=True,
+#                 base_url=config["base_url"]
+#                 if config["model_name"] != "gpt-4o-mini"
+#                 else None,
+#                 default_headers={
+#                     "HTTP-Referer": "https://snowchat.streamlit.app/",
+#                     "X-Title": "Snowchat",
+#                 },
+#             )
+#             if config["model_name"] != "claude-3-haiku-20240307"
+#             else (
+#                 ChatAnthropic(
+#                     model=config["model_name"],
+#                     temperature=0.1,
+#                     max_tokens=700,
+#                     timeout=None,
+#                     max_retries=2,
+#                     callbacks=[self.callback_handler],
+#                     streaming=True,
+#                 )
+#             )
+#         )
+#     def get_chain(self, vectorstore):
+#         def _combine_documents(
+#             docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, document_separator="\n\n"
+#         ):
+#             doc_strings = [format_document(doc, document_prompt) for doc in docs]
+#             return document_separator.join(doc_strings)
+
+#         _inputs = RunnableParallel(
+#             standalone_question=RunnablePassthrough.assign(
+#                 chat_history=lambda x: get_buffer_string(x["chat_history"])
+#             )
+#             | CONDENSE_QUESTION_PROMPT
+#             | OpenAI()
+#             | StrOutputParser(),
+#         )
+#         _context = {
+#             "context": itemgetter("standalone_question")
+#             | vectorstore.as_retriever()
+#             | _combine_documents,
+#             "question": lambda x: x["standalone_question"],
+#         }
+#         conversational_qa_chain = _inputs | _context | QA_PROMPT | self.llm
+
+#         return conversational_qa_chain
+
+
+# def load_chain(model_name="qwen", callback_handler=None):
+#     embeddings = OpenAIEmbeddings(
+#         openai_api_key=st.secrets["OPENAI_API_KEY"], model="text-embedding-ada-002"
+#     )
+#     vectorstore = SupabaseVectorStore(
+#         embedding=embeddings,
+#         client=supabase,
+#         table_name="documents",
+#         query_name="v_match_documents",
+#     )
+
+#     model_type_mapping = {
+#         "gpt-4o-mini": "gpt-4o-mini",
+#         "gemma2-9b": "gemma2-9b",
+#         "claude3-haiku": "claude3-haiku",
+#         "mixtral-8x22b": "mixtral-8x22b",
+#         "llama-3.1-405b": "llama-3.1-405b",
+#     }
+
+#     model_type = model_type_mapping.get(model_name.lower())
+#     if model_type is None:
+#         raise ValueError(f"Unsupported model name: {model_name}")
+
+#     config = ModelConfig(
+#         model_type=model_type, secrets=st.secrets, callback_handler=callback_handler
+#     )
+#     model = ModelWrapper(config)
+#     return model.get_chain(vectorstore)
diff --git a/graph.png b/graph.png
new file mode 100644
index 00000000..dcb88cd0
Binary files /dev/null and b/graph.png differ
diff --git a/ingest.py b/ingest.py
index 67de0f33..c6669f35 100644
--- a/ingest.py
+++ b/ingest.py
@@ -6,6 +6,7 @@
 from langchain.text_splitter import CharacterTextSplitter
 from langchain.vectorstores import SupabaseVectorStore
 from pydantic import BaseModel
+
 from supabase.client import Client, create_client
diff --git a/main.py b/main.py
index 1a491be1..ddbe5d19 100644
--- a/main.py
+++ b/main.py
@@ -2,9 +2,10 @@
 import warnings

 import streamlit as st
+from langchain_core.messages import HumanMessage
 from snowflake.snowpark.exceptions import SnowparkSQLException

-from chain import load_chain
+from agent import MessagesState, create_agent

 # from utils.snow_connect import SnowflakeConnection
 from utils.snowchat_ui import StreamlitUICallbackHandler, message_func
@@ -34,11 +35,11 @@
 st.caption("Talk your way through data")

 model_options = {
-    "gpt-4o-mini": "GPT-4o Mini",
+    "gpt-4o": "GPT-4o",
     "llama-3.1-405b": "Llama 3.1 405B",
-    "gemma2-9b": "Gemma 2 9B",
+    "llama-3.2-3b": "Llama 3.2 3B",
     "claude3-haiku": "Claude 3 Haiku",
-    "mixtral-8x22b": "Mixtral 8x22B",
+    "Gemini Flash 1.5 8B": "Gemini Flash 1.5 8B",
 }

 model = st.radio(
@@ -50,6 +51,9 @@
 )
 st.session_state["model"] = model

+if "assistant_response_processed" not in st.session_state:
+    st.session_state["assistant_response_processed"] = True  # Initialize to True
+
 if "toast_shown" not in st.session_state:
     st.session_state["toast_shown"] = False

@@ -76,6 +80,7 @@
         "content": "Hey there, I'm Chatty McQueryFace, your SQL-speaking sidekick, ready to chat up Snowflake and fetch answers faster than a snowball fight in summer! ❄️🔍",
     },
 ]
+config = {"configurable": {"thread_id": "42"}}

 with open("ui/sidebar.md", "r") as sidebar_file:
     sidebar_content = sidebar_file.read()
@@ -118,18 +123,28 @@
 # Prompt for user input and save
 if prompt := st.chat_input():
     st.session_state.messages.append({"role": "user", "content": prompt})
+    st.session_state["assistant_response_processed"] = (
+        False  # Assistant response not yet processed
+    )

-for message in st.session_state.messages:
+messages_to_display = st.session_state.messages.copy()
+# if not st.session_state["assistant_response_processed"]:
+#     # Exclude the last assistant message if assistant response not yet processed
+#     if messages_to_display and messages_to_display[-1]["role"] == "assistant":
+#         print("\n\nthis is messages_to_display \n\n", messages_to_display)
+#         messages_to_display = messages_to_display[:-1]
+
+for message in messages_to_display:
     message_func(
         message["content"],
-        True if message["role"] == "user" else False,
-        True if message["role"] == "data" else False,
-        model,
+        is_user=(message["role"] == "user"),
+        is_df=(message["role"] == "data"),
+        model=model,
     )

 callback_handler = StreamlitUICallbackHandler(model)

-chain = load_chain(st.session_state["model"], callback_handler)
+react_graph = create_agent(callback_handler, st.session_state["model"])


 def append_chat_history(question, answer):
@@ -148,20 +163,21 @@ def append_message(content, role="assistant"):


 def handle_sql_exception(query, conn, e, retries=2):
-    append_message("Uh oh, I made an error, let me try to fix it..")
-    error_message = (
-        "You gave me a wrong SQL. FIX The SQL query by searching the schema definition: \n```sql\n"
-        + query
-        + "\n```\n Error message: \n "
-        + str(e)
-    )
-    new_query = chain({"question": error_message, "chat_history": ""})["answer"]
-    append_message(new_query)
-    if get_sql(new_query) and retries > 0:
-        return execute_sql(get_sql(new_query), conn, retries - 1)
-    else:
-        append_message("I'm sorry, I couldn't fix the error. Please try again.")
-        return None
+    # append_message("Uh oh, I made an error, let me try to fix it..")
+    # error_message = (
+    #     "You gave me a wrong SQL. FIX The SQL query by searching the schema definition: \n```sql\n"
+    #     + query
+    #     + "\n```\n Error message: \n "
+    #     + str(e)
+    # )
+    # new_query = chain({"question": error_message, "chat_history": ""})["answer"]
+    # append_message(new_query)
+    # if get_sql(new_query) and retries > 0:
+    #     return execute_sql(get_sql(new_query), conn, retries - 1)
+    # else:
+    #     append_message("I'm sorry, I couldn't fix the error. Please try again.")
+    #     return None
+    pass


 def execute_sql(query, conn, retries=2):
@@ -176,20 +192,25 @@
 if (
     "messages" in st.session_state
-    and st.session_state["messages"][-1]["role"] != "assistant"
+    and st.session_state["messages"][-1]["role"] == "user"
+    and not st.session_state["assistant_response_processed"]
 ):
     user_input_content = st.session_state["messages"][-1]["content"]

     if isinstance(user_input_content, str):
+        # Start loading animation
         callback_handler.start_loading_message()

-        result = chain.invoke(
-            {
-                "question": user_input_content,
-                "chat_history": [h for h in st.session_state["history"]],
-            }
-        )
-        append_message(result.content)
+        messages = [HumanMessage(content=user_input_content)]
+
+        state = MessagesState(messages=messages)
+        result = react_graph.invoke(state, config=config, debug=True)
+
+        if result["messages"]:
+            assistant_message = callback_handler.final_message
+            append_message(assistant_message)
+        st.session_state["assistant_response_processed"] = True
+
 if (
     st.session_state["model"] == "Mixtral 8x7B"
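One caveat in the change above: `thread_id` is hardcoded to `"42"`, so every user of a deployed app shares a single LangGraph memory thread. A per-session variant — a sketch, not part of this PR — would key the checkpoint on the Streamlit session:

```python
import uuid

import streamlit as st

# Give each browser session its own checkpoint thread so conversation
# memory does not leak across users (assumes one thread per session).
if "thread_id" not in st.session_state:
    st.session_state["thread_id"] = str(uuid.uuid4())

config = {"configurable": {"thread_id": st.session_state["thread_id"]}}
```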
diff --git a/requirements.txt b/requirements.txt
index 98406b92..3a356826 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,15 +1,14 @@
-langchain==0.2.12
-pandas==1.5.0
-pydantic==1.10.8
+langchain==0.3.3
+langchain_anthropic==0.2.3
+langchain_community==0.3.2
+langchain_core==0.3.12
+langchain_openai==0.2.2
+langchain-google-genai==2.0.1
+langgraph==0.2.38
+pydantic==2.9.2
+Requests==2.32.3
+snowflake_connector_python==3.1.0
 snowflake_snowpark_python==1.5.0
-snowflake-snowpark-python[pandas]
-streamlit==1.31.0
-supabase==2.4.1
-unstructured
-tiktoken
-openai
-black
-langchain_openai
-langchain-community
-langchain-core
-langchain-anthropic
\ No newline at end of file
+streamlit==1.33.0
+websocket_client==1.7.0
+duckduckgo_search==6.3.0
\ No newline at end of file
diff --git a/template.py b/template.py
index c8cd086c..5cc1759a 100644
--- a/template.py
+++ b/template.py
@@ -1,4 +1,3 @@
-from langchain.prompts.prompt import PromptTemplate
 from langchain_core.prompts import ChatPromptTemplate

 template = """You are an AI chatbot having a conversation with a human.
diff --git a/tools.py b/tools.py
new file mode 100644
index 00000000..fa89599e
--- /dev/null
+++ b/tools.py
@@ -0,0 +1,39 @@
+import streamlit as st
+from supabase.client import Client, create_client
+from langchain_openai import OpenAIEmbeddings
+from langchain_community.vectorstores import SupabaseVectorStore
+from langchain.tools.retriever import create_retriever_tool
+from langchain_community.tools import DuckDuckGoSearchRun
+from utils.snow_connect import SnowflakeConnection
+
+supabase_url = st.secrets["SUPABASE_URL"]
+supabase_key = st.secrets["SUPABASE_SERVICE_KEY"]
+supabase: Client = create_client(supabase_url, supabase_key)
+
+embeddings = OpenAIEmbeddings(
+    openai_api_key=st.secrets["OPENAI_API_KEY"], model="text-embedding-ada-002"
+)
+vectorstore = SupabaseVectorStore(
+    embedding=embeddings,
+    client=supabase,
+    table_name="documents",
+    query_name="v_match_documents",
+)
+
+retriever_tool = create_retriever_tool(
+    vectorstore.as_retriever(),
+    name="Database_Schema",
+    description="Search for database schema details",
+)
+
+search = DuckDuckGoSearchRun()
+
+
+def sql_executor_tool(query: str, use_cache: bool = True) -> str:
+    """
+    Execute Snowflake SQL queries with optional caching.
+    """
+    conn = SnowflakeConnection()
+    return conn.execute_query(query, use_cache)
+
+
+# if __name__ == "__main__":
+#     print(sql_executor_tool("select * from STREAM_HACKATHON.STREAMLIT.CUSTOMER_DETAILS"))
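The three tools can be smoke-tested in isolation — a sketch, assuming the secrets above are configured and the Supabase `documents` table is populated; the query strings are made-up examples:

```python
from tools import retriever_tool, search, sql_executor_tool

# Retriever tool: returns matching schema snippets as one string.
print(retriever_tool.invoke("customer table columns"))

# Web search tool: DuckDuckGo results as a single string.
print(search.invoke("Snowflake QUALIFY clause syntax"))

# SQL executor: read-only query, served from Cloudflare KV on a cache hit.
print(sql_executor_tool("SELECT CURRENT_VERSION()"))
```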
+ """ + conn = SnowflakeConnection() + return conn.execute_query(query, use_cache) + +# if __name__ == "__main__": +# print(sql_executor_tool("select * from STREAM_HACKATHON.STREAMLIT.CUSTOMER_DETAILS")) diff --git a/utils/snow_connect.py b/utils/snow_connect.py index 2268c8bd..f525adc3 100644 --- a/utils/snow_connect.py +++ b/utils/snow_connect.py @@ -1,13 +1,13 @@ from typing import Any, Dict - +import json +import requests import streamlit as st from snowflake.snowpark.session import Session -from snowflake.snowpark.version import VERSION class SnowflakeConnection: """ - This class is used to establish a connection to Snowflake. + This class is used to establish a connection to Snowflake and execute queries with optional caching. Attributes ---------- @@ -20,16 +20,24 @@ class SnowflakeConnection: ------- get_session() Establishes and returns the Snowflake connection session. - + execute_query(query: str, use_cache: bool = True) + Executes a Snowflake SQL query with optional caching. """ def __init__(self): self.connection_parameters = self._get_connection_parameters_from_env() self.session = None + self.cloudflare_account_id = st.secrets["CLOUDFLARE_ACCOUNT_ID"] + self.cloudflare_namespace_id = st.secrets["CLOUDFLARE_NAMESPACE_ID"] + self.cloudflare_api_token = st.secrets["CLOUDFLARE_API_TOKEN"] + self.headers = { + "Authorization": f"Bearer {self.cloudflare_api_token}", + "Content-Type": "application/json" + } @staticmethod def _get_connection_parameters_from_env() -> Dict[str, Any]: - connection_parameters = { + return { "account": st.secrets["ACCOUNT"], "user": st.secrets["USER_NAME"], "password": st.secrets["PASSWORD"], @@ -38,7 +46,6 @@ def _get_connection_parameters_from_env() -> Dict[str, Any]: "schema": st.secrets["SCHEMA"], "role": st.secrets["ROLE"], } - return connection_parameters def get_session(self): """ @@ -50,3 +57,45 @@ def get_session(self): self.session = Session.builder.configs(self.connection_parameters).create() self.session.sql_simplifier_enabled = True return self.session + + def _construct_kv_url(self, key: str) -> str: + return f"https://api.cloudflare.com/client/v4/accounts/{self.cloudflare_account_id}/storage/kv/namespaces/{self.cloudflare_namespace_id}/values/{key}" + + def get_from_cache(self, key: str) -> str: + url = self._construct_kv_url(key) + try: + response = requests.get(url, headers=self.headers) + response.raise_for_status() + print("\n\n\nCache hit\n\n\n") + return response.text + except requests.exceptions.RequestException as e: + print(f"Cache miss or error: {e}") + return None + + def set_to_cache(self, key: str, value: str) -> None: + url = self._construct_kv_url(key) + serialized_value = json.dumps(value) + try: + response = requests.put(url, headers=self.headers, data=serialized_value) + response.raise_for_status() + print("Cache set successfully") + except requests.exceptions.RequestException as e: + print(f"Failed to set cache: {e}") + + def execute_query(self, query: str, use_cache: bool = True) -> str: + """ + Execute a Snowflake SQL query with optional caching. 
+ """ + if use_cache: + cached_response = self.get_from_cache(query) + if cached_response: + return json.loads(cached_response) + + session = self.get_session() + result = session.sql(query).collect() + result_list = [row.as_dict() for row in result] + + if use_cache: + self.set_to_cache(query, result_list) + + return result_list diff --git a/utils/snowchat_ui.py b/utils/snowchat_ui.py index 03f5f58f..05db4b1e 100644 --- a/utils/snowchat_ui.py +++ b/utils/snowchat_ui.py @@ -4,7 +4,6 @@ import streamlit as st from langchain.callbacks.base import BaseCallbackHandler - image_url = f"{st.secrets['SUPABASE_STORAGE_URL']}/storage/v1/object/public/snowchat/" gemini_url = image_url + "google-gemini-icon.png?t=2024-05-07T21%3A17%3A52.235Z" mistral_url = ( @@ -35,6 +34,8 @@ def get_model_url(model_name): return snow_url elif "gpt" in model_name.lower(): return openai_url + elif "gemini" in model_name.lower(): + return gemini_url return mistral_url @@ -61,7 +62,7 @@ def format_message(text): def message_func(text, is_user=False, is_df=False, model="gpt"): """ - This function is used to display the messages in the chatbot UI. + This function displays messages in the chatbot UI, ensuring proper alignment and avatar positioning. Parameters: text (str): The text to be displayed. @@ -69,52 +70,36 @@ def message_func(text, is_user=False, is_df=False, model="gpt"): is_df (bool): Whether the message is a dataframe or not. """ model_url = get_model_url(model) - - avatar_url = model_url - if is_user: - avatar_url = user_url - message_alignment = "flex-end" - message_bg_color = "linear-gradient(135deg, #00B2FF 0%, #006AFF 100%)" - avatar_class = "user-avatar" - st.write( - f""" -