diff --git a/chain.py b/chain.py index edd9b8b..2bff972 100644 --- a/chain.py +++ b/chain.py @@ -33,7 +33,7 @@ class ModelConfig(BaseModel): @validator("model_type", pre=True, always=True) def validate_model_type(cls, v): - if v not in ["gpt", "mistral", "gemini"]: + if v not in ["gpt", "mixtral8x22b", "claude", "mixtral8x7b"]: raise ValueError(f"Unsupported model type: {v}") return v @@ -52,23 +52,26 @@ def __init__(self, config: ModelConfig): def setup(self): if self.model_type == "gpt": self.setup_gpt() - elif self.model_type == "gemini": - self.setup_gemini() - elif self.model_type == "mistral": - self.setup_mixtral() + elif self.model_type == "claude": + self.setup_claude() + elif self.model_type == "mixtral8x7b": + self.setup_mixtral_8x7b() + elif self.model_type == "mixtral8x22b": + self.setup_mixtral_8x22b() + def setup_gpt(self): self.llm = ChatOpenAI( - model_name="gpt-3.5-turbo-0125", + model_name="gpt-3.5-turbo", temperature=0.2, api_key=self.secrets["OPENAI_API_KEY"], max_tokens=1000, callbacks=[self.callback_handler], streaming=True, - base_url=self.gateway_url, + # base_url=self.gateway_url, ) - def setup_mixtral(self): + def setup_mixtral_8x7b(self): self.llm = ChatOpenAI( model_name="mixtral-8x7b-32768", temperature=0.2, @@ -79,12 +82,27 @@ def setup_mixtral(self): base_url="https://api.groq.com/openai/v1", ) - def setup_gemini(self): + def setup_claude(self): self.llm = ChatOpenAI( - model_name="google/gemini-pro", - temperature=0.2, + model_name="anthropic/claude-3-haiku", + temperature=0.1, + api_key=self.secrets["OPENROUTER_API_KEY"], + max_tokens=700, + callbacks=[self.callback_handler], + streaming=True, + base_url="https://openrouter.ai/api/v1", + default_headers={ + "HTTP-Referer": "https://snowchat.streamlit.app/", + "X-Title": "Snowchat", + }, + ) + + def setup_mixtral_8x22b(self): + self.llm = ChatOpenAI( + model_name="mistralai/mixtral-8x22b", + temperature=0.1, api_key=self.secrets["OPENROUTER_API_KEY"], - max_tokens=1200, + 
max_tokens=700, callbacks=[self.callback_handler], streaming=True, base_url="https://openrouter.ai/api/v1", @@ -133,10 +151,12 @@ def load_chain(model_name="GPT-3.5", callback_handler=None): if "GPT-3.5" in model_name: model_type = "gpt" - elif "mistral" in model_name.lower(): - model_type = "mistral" - elif "gemini" in model_name.lower(): - model_type = "gemini" + elif "mixtral 8x7b" in model_name.lower(): + model_type = "mixtral8x7b" + elif "claude" in model_name.lower(): + model_type = "claude" + elif "mixtral 8x22b" in model_name.lower(): + model_type = "mixtral8x22b" else: raise ValueError(f"Unsupported model name: {model_name}") diff --git a/main.py b/main.py index 894db25..644ddc0 100644 --- a/main.py +++ b/main.py @@ -34,7 +34,7 @@ st.caption("Talk your way through data") model = st.radio( "", - options=["GPT-3.5 - OpenAI", "Gemini 1.5 - Openrouter", "Mistral 8x7B - Groq"], + options=["Claude-3 Haiku", "Mixtral 8x7B", "Mixtral 8x22B", "GPT-3.5"], index=0, horizontal=True, ) @@ -43,12 +43,20 @@ if "toast_shown" not in st.session_state: st.session_state["toast_shown"] = False +if "rate-limit" not in st.session_state: + st.session_state["rate-limit"] = False + # Show the toast only if it hasn't been shown before if not st.session_state["toast_shown"]: st.toast("The snowflake data retrieval is disabled for now.", icon="👋") st.session_state["toast_shown"] = True -if st.session_state["model"] == "👑 Mistral 8x7B - Groq": +# Show a warning if the model is rate-limited +if st.session_state['rate-limit']: + st.toast("Probably rate limited.. Go easy folks", icon="⚠️") + st.session_state['rate-limit'] = False + +if st.session_state["model"] == "Mixtral 8x7B": st.warning("This is highly rate-limited. 
Please use it sparingly", icon="⚠️") INITIAL_MESSAGE = [ @@ -173,6 +181,9 @@ def execute_sql(query, conn, retries=2): ) append_message(result.content) +if st.session_state["model"] == "Mixtral 8x7B" and st.session_state['messages'][-1]['content'] == "": + st.session_state['rate-limit'] = True + # if get_sql(result): # conn = SnowflakeConnection().get_session() # df = execute_sql(get_sql(result), conn) diff --git a/template.py b/template.py index 58e222f..c8cd086 100644 --- a/template.py +++ b/template.py @@ -17,7 +17,9 @@ When asked about your capabilities, provide a general overview of your ability to assist with data analysis tasks using Snowflake SQL, instead of performing specific SQL queries. -Based on the question provided, if it pertains to data analysis or SQL tasks, generate SQL code that is compatible with the Snowflake environment. Additionally, offer a brief explanation about how you arrived at the SQL code. If the required column isn't explicitly stated in the context, suggest an alternative using available columns, but do not assume the existence of any columns that are not mentioned. Also, do not modify the database in any way (no insert, update, or delete operations). You are only allowed to query the database. Refrain from using the information schema. +(CONTEXT IS NOT KNOWN TO USER) it is provided to you as a reference to generate SQL code. + +Based on the question provided, if it pertains to data analysis or SQL tasks, generate SQL code based on the Context provided. Make sure that it is compatible with the Snowflake environment. Additionally, offer a brief explanation about how you arrived at the SQL code. If the required column isn't explicitly stated in the context, suggest an alternative using available columns, but do not assume the existence of any columns that are not mentioned. Also, do not modify the database in any way (no insert, update, or delete operations). You are only allowed to query the database. 
Refrain from using the information schema. **You are only required to write one SQL query per question.** If the question or context does not clearly involve SQL or data analysis tasks, respond appropriately without generating SQL queries. @@ -28,7 +30,14 @@ Write your response in markdown format. -User: {question} +Do not worry about access to the database or the schema details. The context provided is sufficient to generate the SQL code. The SQL code is not expected to run on any database. + +User Question: \n {question} + + +\n +Context - (Schema Details): +\n {context} Assistant: diff --git a/ui/sidebar.md b/ui/sidebar.md index d41df42..c6de714 100644 --- a/ui/sidebar.md +++ b/ui/sidebar.md @@ -12,7 +12,7 @@ SnowChat is an intuitive and user-friendly application that allows you to intera Here are some example queries you can try with SnowChat: -- Show me the total revenue for each product category. +- Write SQL code to show me the total revenue for each product category. - Who are the top 10 customers by sales? - What is the average order value for each region? - How many orders were placed last week? 
diff --git a/utils/snowchat_ui.py b/utils/snowchat_ui.py index 30e3446..2b0ccbc 100644 --- a/utils/snowchat_ui.py +++ b/utils/snowchat_ui.py @@ -11,15 +11,18 @@ image_url + "png-transparent-openai-chatgpt-logo-thumbnail.png?t=2024-03-01T07%3A41%3A47.586Z" ) - +user_url = image_url + "cat-with-sunglasses.png" +claude_url = image_url + "Claude.png?t=2024-03-13T23%3A47%3A16.824Z" def get_model_url(model_name): if "gpt" in model_name.lower(): return openai_url + elif "claude" in model_name.lower(): + return claude_url + elif "mixtral" in model_name.lower(): + return mistral_url elif "gemini" in model_name.lower(): return gemini_url - elif "mistral" in model_name.lower(): - return mistral_url return mistral_url @@ -57,7 +60,7 @@ def message_func(text, is_user=False, is_df=False, model="gpt"): avatar_url = model_url if is_user: - avatar_url = "https://avataaars.io/?avatarStyle=Transparent&topType=ShortHairShortFlat&accessoriesType=Prescription01&hairColor=Auburn&facialHairType=BeardLight&facialHairColor=Black&clotheType=Hoodie&clotheColor=PastelBlue&eyeType=Squint&eyebrowType=DefaultNatural&mouthType=Smile&skinColor=Tanned" + avatar_url = user_url message_alignment = "flex-end" message_bg_color = "linear-gradient(135deg, #00B2FF 0%, #006AFF 100%)" avatar_class = "user-avatar"