From b1d1715b0510985976e41b0ceb913aec842830ac Mon Sep 17 00:00:00 2001 From: kaarthik108 Date: Tue, 19 Dec 2023 09:55:32 +1300 Subject: [PATCH] Add Mixtral --- README.md | 12 +++--------- chain.py | 50 ++++++++++++++++++++++++++---------------------- main.py | 2 +- requirements.txt | 2 +- 4 files changed, 32 insertions(+), 34 deletions(-) diff --git a/README.md b/README.md index 3e409fe..accad71 100644 --- a/README.md +++ b/README.md @@ -9,26 +9,22 @@ [![Streamlit App](https://static.streamlit.io/badges/streamlit_badge_black_white.svg)](https://snowchat.streamlit.app/) - ![156shots_so](https://github.com/kaarthik108/snowChat/assets/53030784/7538d25b-a2d4-4a2c-9601-fb4c7db3c0b6) - **snowChat** is an intuitive and user-friendly application that allows users to interact with their Snowflake data using natural language queries. Type in your questions or requests, and SnowChat will generate the appropriate SQL query and return the data you need. No more complex SQL queries or digging through tables - SnowChat makes it easy to access your data! By bringing data one step closer, SnowChat empowers users to make data-driven decisions faster and more efficiently, reducing the barriers between users and the insights they seek. ## Supported LLM's + - GPT-3.5-turbo-16k -- Code-llama-13B-instruct - Claude-instant-v1 +- Mixtral 8x7B # - https://github.com/kaarthik108/snowChat/assets/53030784/24105e23-69d3-4676-b6d6-d8157dd1580a - # - ## 🌟 Features - **Conversational AI**: Harnesses ChatGPT to translate natural language into precise SQL queries. @@ -37,7 +33,6 @@ https://github.com/kaarthik108/snowChat/assets/53030784/24105e23-69d3-4676-b6d6- - **Self-healing SQL**: Proactively suggests solutions for SQL errors, streamlining data access. - **Interactive User Interface**: Transforms data querying into an engaging conversation, complete with a chat reset option. - ## 🛠️ Installation 1. Clone this repository: @@ -47,7 +42,7 @@ https://github.com/kaarthik108/snowChat/assets/53030784/24105e23-69d3-4676-b6d6- cd snowchat pip install -r requirements.txt -3. Set up your `OPENAI_API_KEY`, `ACCOUNT`, `USER_NAME`, `PASSWORD`, `ROLE`, `DATABASE`, `SCHEMA`, `WAREHOUSE`, `SUPABASE_URL` , `SUPABASE_SERVICE_KEY` and `REPLICATE_API_TOKEN` in project directory `secrets.toml`. +3. Set up your `OPENAI_API_KEY`, `ACCOUNT`, `USER_NAME`, `PASSWORD`, `ROLE`, `DATABASE`, `SCHEMA`, `WAREHOUSE`, `SUPABASE_URL` , `SUPABASE_SERVICE_KEY` and `REPLICATE_API_TOKEN` in project directory `secrets.toml`. 4. Make you're schemas and store them in docs folder that matches you're database. @@ -68,7 +63,6 @@ https://github.com/kaarthik108/snowChat/assets/53030784/24105e23-69d3-4676-b6d6- [![Star History Chart](https://api.star-history.com/svg?repos=kaarthik108/snowChat&type=Date)] - ## 🤝 Contributing Feel free to contribute to this project by submitting a pull request or opening an issue. Your feedback and suggestions are greatly appreciated! diff --git a/chain.py b/chain.py index 4c65956..025a016 100644 --- a/chain.py +++ b/chain.py @@ -6,7 +6,7 @@ from langchain.chains.question_answering import load_qa_chain from langchain.chat_models import ChatOpenAI from langchain.embeddings.openai import OpenAIEmbeddings -from langchain.llms import OpenAI, Replicate +from langchain.llms import OpenAI from langchain.llms.bedrock import Bedrock from langchain.vectorstores import SupabaseVectorStore from pydantic import BaseModel, validator @@ -18,9 +18,6 @@ supabase_key = st.secrets["SUPABASE_SERVICE_KEY"] supabase: Client = create_client(supabase_url, supabase_key) -VERSION = "1f01a52ff933873dff339d5fb5e1fd6f24f77456836f514fa05e91c1a42699c7" -LLAMA = "meta/codellama-13b-instruct:{}".format(VERSION) - class ModelConfig(BaseModel): model_type: str @@ -29,7 +26,7 @@ class ModelConfig(BaseModel): @validator("model_type", pre=True, always=True) def validate_model_type(cls, v): - if v not in ["code-llama", "gpt", "claude"]: + if v not in ["gpt", "claude", "mixtral"]: raise ValueError(f"Unsupported model type: {v}") return v @@ -42,26 +39,12 @@ def __init__(self, config: ModelConfig): self.setup() def setup(self): - if self.model_type == "code-llama": - self.setup_llama() - elif self.model_type == "gpt": + if self.model_type == "gpt": self.setup_gpt() elif self.model_type == "claude": self.setup_claude() - - def setup_llama(self): - self.q_llm = Replicate( - model=LLAMA, - input={"temperature": 0.2, "max_length": 200, "top_p": 1}, - replicate_api_token=self.secrets["REPLICATE_API_TOKEN"], - ) - self.llm = Replicate( - streaming=True, - callbacks=[self.callback_handler], - model=LLAMA, - input={"temperature": 0.2, "max_length": 300, "top_p": 1}, - replicate_api_token=self.secrets["REPLICATE_API_TOKEN"], - ) + elif self.model_type == "mixtral": + self.setup_mixtral() def setup_gpt(self): self.q_llm = OpenAI( @@ -80,6 +63,25 @@ def setup_gpt(self): streaming=True, ) + def setup_mixtral(self): + self.q_llm = OpenAI( + temperature=0.1, + openai_api_key=self.secrets["MIXTRAL_API_KEY"], + model_name="mistralai/Mixtral-8x7B-Instruct-v0.1", + max_tokens=500, + base_url="https://api.together.xyz/v1", + ) + + self.llm = ChatOpenAI( + model_name="mistralai/Mixtral-8x7B-Instruct-v0.1", + temperature=0.5, + openai_api_key=self.secrets["MIXTRAL_API_KEY"], + max_tokens=500, + callbacks=[self.callback_handler], + streaming=True, + base_url="https://api.together.xyz/v1", + ) + def setup_claude(self): bedrock_runtime = boto3.client( service_name="bedrock-runtime", @@ -132,8 +134,10 @@ def load_chain(model_name="GPT-3.5", callback_handler=None): model_type = "claude" elif "GPT-3.5" in model_name: model_type = "gpt" + elif "mixtral" in model_name.lower(): + model_type = "mixtral" else: - model_type = "code-llama" + raise ValueError(f"Unsupported model name: {model_name}") config = ModelConfig( model_type=model_type, secrets=st.secrets, callback_handler=callback_handler diff --git a/main.py b/main.py index 020f810..6e96d9e 100644 --- a/main.py +++ b/main.py @@ -17,7 +17,7 @@ st.caption("Talk your way through data") model = st.radio( "", - options=["✨ GPT-3.5", "🐐 code-LLama", "♾️ Claude"], + options=["✨ GPT-3.5", "♾️ Claude", "⛰️ Mixtral"], index=0, horizontal=True, ) diff --git a/requirements.txt b/requirements.txt index 94f9151..4aa7a89 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -langchain==0.0.305 +langchain==0.0.350 pandas==1.5.0 pydantic==1.10.8 snowflake_snowpark_python==1.5.0