use mistral from Groq
kaarthik108 committed Mar 1, 2024
1 parent 709b334 commit e27441a
Showing 5 changed files with 99 additions and 68 deletions.
8 changes: 7 additions & 1 deletion .gitignore
@@ -8,4 +8,10 @@ archived_logs/
build/
snowchat.egg-info/

chroma_db
chroma_db

pplx.py

test.json
test.*
app.py
60 changes: 19 additions & 41 deletions chain.py
@@ -1,8 +1,7 @@
from typing import Any, Callable, Dict, Optional

import boto3
import streamlit as st
from langchain.chat_models import BedrockChat, ChatOpenAI
from langchain.chat_models import ChatOpenAI
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.llms import OpenAI
from langchain.vectorstores import SupabaseVectorStore
@@ -34,7 +33,7 @@ class ModelConfig(BaseModel):

@validator("model_type", pre=True, always=True)
def validate_model_type(cls, v):
if v not in ["gpt", "codellama", "mistral"]:
if v not in ["gpt", "mistral", "gemini"]:
raise ValueError(f"Unsupported model type: {v}")
return v

@@ -53,8 +52,8 @@ def __init__(self, config: ModelConfig):
def setup(self):
if self.model_type == "gpt":
self.setup_gpt()
elif self.model_type == "codellama":
self.setup_codellama()
elif self.model_type == "gemini":
self.setup_gemini()
elif self.model_type == "mistral":
self.setup_mixtral()

@@ -63,59 +62,38 @@ def setup_gpt(self):
model_name="gpt-3.5-turbo-0125",
temperature=0.2,
api_key=self.secrets["OPENAI_API_KEY"],
max_tokens=500,
max_tokens=1000,
callbacks=[self.callback_handler],
streaming=True,
base_url=self.gateway_url,
)

def setup_mixtral(self):
self.llm = ChatOpenAI(
model_name="mistralai/mistral-medium",
model_name="mixtral-8x7b-32768",
temperature=0.2,
api_key=self.secrets["OPENROUTER_API_KEY"],
max_tokens=500,
api_key=self.secrets["GROQ_API_KEY"],
max_tokens=3000,
callbacks=[self.callback_handler],
streaming=True,
base_url="https://openrouter.ai/api/v1",
base_url="https://api.groq.com/openai/v1",
)

def setup_codellama(self):
def setup_gemini(self):
self.llm = ChatOpenAI(
model_name="codellama/codellama-70b-instruct",
model_name="google/gemini-pro",
temperature=0.2,
api_key=self.secrets["OPENROUTER_API_KEY"],
max_tokens=500,
max_tokens=1200,
callbacks=[self.callback_handler],
streaming=True,
base_url="https://openrouter.ai/api/v1",
default_headers={
"HTTP-Referer": "https://snowchat.streamlit.app/",
"X-Title": "Snowchat",
},
)

# def setup_claude(self):
# bedrock_runtime = boto3.client(
# service_name="bedrock-runtime",
# aws_access_key_id=self.secrets["AWS_ACCESS_KEY_ID"],
# aws_secret_access_key=self.secrets["AWS_SECRET_ACCESS_KEY"],
# region_name="us-east-1",
# )
# parameters = {
# "max_tokens_to_sample": 1000,
# "stop_sequences": [],
# "temperature": 0,
# "top_p": 0.9,
# }
# self.q_llm = BedrockChat(
# model_id="anthropic.claude-instant-v1", client=bedrock_runtime
# )

# self.llm = BedrockChat(
# model_id="anthropic.claude-instant-v1",
# client=bedrock_runtime,
# callbacks=[self.callback_handler],
# streaming=True,
# model_kwargs=parameters,
# )

def get_chain(self, vectorstore):
def _combine_documents(
docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, document_separator="\n\n"
@@ -153,12 +131,12 @@ def load_chain(model_name="GPT-3.5", callback_handler=None):
query_name="v_match_documents",
)

if "codellama" in model_name.lower():
model_type = "codellama"
elif "GPT-3.5" in model_name:
if "GPT-3.5" in model_name:
model_type = "gpt"
elif "mistral" in model_name.lower():
model_type = "mistral"
elif "gemini" in model_name.lower():
model_type = "gemini"
else:
raise ValueError(f"Unsupported model name: {model_name}")

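Note: the switch from OpenRouter to Groq works because Groq exposes an OpenAI-compatible endpoint, so the existing ChatOpenAI wrapper only needs a new base_url, model name, and API key. A minimal standalone sketch of the setup_mixtral configuration above (not part of the commit; it assumes a GROQ_API_KEY entry in Streamlit secrets and omits the streaming callback handler):

import streamlit as st
from langchain.chat_models import ChatOpenAI

# Point the OpenAI-compatible client at Groq's endpoint, mirroring setup_mixtral above.
llm = ChatOpenAI(
    model_name="mixtral-8x7b-32768",            # Mixtral 8x7B hosted on Groq
    temperature=0.2,
    api_key=st.secrets["GROQ_API_KEY"],         # same secret name the diff reads
    max_tokens=3000,
    streaming=True,
    base_url="https://api.groq.com/openai/v1",  # Groq's OpenAI-compatible API
)

# Simple string-in/string-out call; works even without a callback handler attached.
print(llm.predict("Write one sentence about Snowflake."))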
36 changes: 31 additions & 5 deletions main.py
@@ -14,16 +14,43 @@
chat_history = []
snow_ddl = Snowddl()

st.title("snowChat")
gradient_text_html = """
<style>
.gradient-text {
font-weight: bold;
background: -webkit-linear-gradient(left, red, orange);
background: linear-gradient(to right, red, orange);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
display: inline;
font-size: 3em;
}
</style>
<div class="gradient-text">snowChat</div>
"""

st.markdown(gradient_text_html, unsafe_allow_html=True)

st.caption("Talk your way through data")
model = st.radio(
"",
options=["GPT-3.5", "♾️ codellama", "👑 Mistral"],
options=["GPT-3.5 - OpenAI", "Gemini 1.5 - Openrouter", "Mistral 8x7B - Groq"],
index=0,
horizontal=True,
)
st.session_state["model"] = model

if "toast_shown" not in st.session_state:
st.session_state["toast_shown"] = False

# Show the toast only if it hasn't been shown before
if not st.session_state["toast_shown"]:
st.toast("The snowflake data retrieval is disabled for now.", icon="👋")
st.session_state["toast_shown"] = True

if st.session_state["model"] == "👑 Mistral 8x7B - Groq":
st.warning("This is highly rate-limited. Please use it sparingly", icon="⚠️")

INITIAL_MESSAGE = [
{"role": "user", "content": "Hi!"},
{
@@ -38,10 +65,8 @@
with open("ui/styles.md", "r") as styles_file:
styles_content = styles_file.read()

# Display the DDL for the selected table
st.sidebar.markdown(sidebar_content)

# Create a sidebar with a dropdown menu
selected_table = st.sidebar.selectbox(
"Select a table:", options=list(snow_ddl.ddl_dict.keys())
)
@@ -81,9 +106,10 @@
message["content"],
True if message["role"] == "user" else False,
True if message["role"] == "data" else False,
model,
)

callback_handler = StreamlitUICallbackHandler()
callback_handler = StreamlitUICallbackHandler(model)

chain = load_chain(st.session_state["model"], callback_handler)

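Note on the main.py hunk above: the rate-limit warning compares st.session_state["model"] against the old emoji-prefixed label "👑 Mistral 8x7B - Groq", while the new radio option is "Mistral 8x7B - Groq", so as written the warning never fires. A small sketch of a more forgiving check (an editorial suggestion, not part of the commit):

import streamlit as st

# Substring match survives label tweaks such as emoji prefixes or provider suffixes.
if "mistral" in st.session_state["model"].lower():
    st.warning("This is highly rate-limited. Please use it sparingly", icon="⚠️")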
10 changes: 8 additions & 2 deletions ui/styles.md
@@ -11,8 +11,14 @@
background-color: white;
z-index: 100;
}
h1 {
font-family: 'Roboto Slab', serif;
h1, h2 {
font-weight: bold;
background: -webkit-linear-gradient(left, red, orange);
background: linear-gradient(to right, red, orange);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
display: inline;
font-size: 3em;
}
.user-avatar {
float: right;
53 changes: 34 additions & 19 deletions utils/snowchat_ui.py
@@ -4,6 +4,24 @@
import streamlit as st
from langchain.callbacks.base import BaseCallbackHandler

image_url = f"{st.secrets['SUPABASE_STORAGE_URL']}/storage/v1/object/public/snowchat/"
gemini_url = image_url + "google-gemini-icon.png?t=2024-03-01T07%3A25%3A59.637Z"
mistral_url = image_url + "mistral-ai-icon-logo-B3319DCA6B-seeklogo.com.png"
openai_url = (
image_url
+ "png-transparent-openai-chatgpt-logo-thumbnail.png?t=2024-03-01T07%3A41%3A47.586Z"
)


def get_model_url(model_name):
if "gpt" in model_name.lower():
return openai_url
elif "gemini" in model_name.lower():
return gemini_url
elif "mistral" in model_name.lower():
return mistral_url
return mistral_url


def format_message(text):
"""
@@ -26,7 +44,7 @@ def format_message(text):
return formatted_text


def message_func(text, is_user=False, is_df=False):
def message_func(text, is_user=False, is_df=False, model="gpt"):
"""
This function is used to display the messages in the chatbot UI.
@@ -35,6 +53,9 @@ def message_func(text, is_user=False, is_df=False):
is_user (bool): Whether the message is from the user or not.
is_df (bool): Whether the message is a dataframe or not.
"""
model_url = get_model_url(model)

avatar_url = model_url
if is_user:
avatar_url = "https://avataaars.io/?avatarStyle=Transparent&topType=ShortHairShortFlat&accessoriesType=Prescription01&hairColor=Auburn&facialHairType=BeardLight&facialHairColor=Black&clotheType=Hoodie&clotheColor=PastelBlue&eyeType=Squint&eyebrowType=DefaultNatural&mouthType=Smile&skinColor=Tanned"
message_alignment = "flex-end"
@@ -45,13 +66,12 @@
<div style="display: flex; align-items: center; margin-bottom: 10px; justify-content: {message_alignment};">
<div style="background: {message_bg_color}; color: white; border-radius: 20px; padding: 10px; margin-right: 5px; max-width: 75%; font-size: 14px;">
{text} \n </div>
<img src="{avatar_url}" class="{avatar_class}" alt="avatar" style="width: 50px; height: 50px;" />
<img src="{avatar_url}" class="{avatar_class}" alt="avatar" style="width: 40px; height: 40px;" />
</div>
""",
unsafe_allow_html=True,
)
else:
avatar_url = "https://avataaars.io/?avatarStyle=Transparent&topType=WinterHat2&accessoriesType=Kurt&hatColor=Blue01&facialHairType=MoustacheMagnum&facialHairColor=Blonde&clotheType=Overall&clotheColor=Gray01&eyeType=WinkWacky&eyebrowType=SadConcernedNatural&mouthType=Sad&skinColor=Light"
message_alignment = "flex-start"
message_bg_color = "#71797E"
avatar_class = "bot-avatar"
Expand All @@ -60,7 +80,7 @@ def message_func(text, is_user=False, is_df=False):
st.write(
f"""
<div style="display: flex; align-items: center; margin-bottom: 10px; justify-content: {message_alignment};">
<img src="{avatar_url}" class="{avatar_class}" alt="avatar" style="width: 50px; height: 50px;" />
<img src="{model_url}" class="{avatar_class}" alt="avatar" style="width: 50px; height: 50px;" />
</div>
""",
unsafe_allow_html=True,
@@ -73,8 +93,8 @@
st.write(
f"""
<div style="display: flex; align-items: center; margin-bottom: 10px; justify-content: {message_alignment};">
<img src="{avatar_url}" class="{avatar_class}" alt="avatar" style="width: 50px; height: 50px;" />
<div style="background: {message_bg_color}; color: white; border-radius: 20px; padding: 10px; margin-right: 5px; max-width: 75%; font-size: 14px;">
<img src="{avatar_url}" class="{avatar_class}" alt="avatar" style="width: 30px; height: 30px;" />
<div style="background: {message_bg_color}; color: white; border-radius: 20px; padding: 10px; margin-right: 5px; margin-left: 5px; max-width: 75%; font-size: 14px;">
{text} \n </div>
</div>
""",
Expand All @@ -83,11 +103,13 @@ def message_func(text, is_user=False, is_df=False):


class StreamlitUICallbackHandler(BaseCallbackHandler):
def __init__(self):
def __init__(self, model):
self.token_buffer = []
self.placeholder = st.empty()
self.has_streaming_ended = False
self.has_streaming_started = False
self.model = model
self.avatar_url = get_model_url(model)

def start_loading_message(self):
loading_message_content = self._get_bot_message_container("Thinking...")
@@ -109,17 +131,11 @@ def on_llm_end(self, response, run_id, parent_run_id=None, **kwargs):

def _get_bot_message_container(self, text):
"""Generate the bot's message container style for the given text."""
avatar_url = "https://avataaars.io/?avatarStyle=Transparent&topType=WinterHat2&accessoriesType=Kurt&hatColor=Blue01&facialHairType=MoustacheMagnum&facialHairColor=Blonde&clotheType=Overall&clotheColor=Gray01&eyeType=WinkWacky&eyebrowType=SadConcernedNatural&mouthType=Sad&skinColor=Light"
message_alignment = "flex-start"
message_bg_color = "#71797E"
avatar_class = "bot-avatar"
formatted_text = format_message(
text
) # Ensure this handles "Thinking..." appropriately.
formatted_text = format_message(text)
container_content = f"""
<div style="display: flex; align-items: center; margin-bottom: 10px; justify-content: {message_alignment};">
<img src="{avatar_url}" class="{avatar_class}" alt="avatar" style="width: 50px; height: 50px;" />
<div style="background: {message_bg_color}; color: white; border-radius: 20px; padding: 10px; margin-right: 5px; max-width: 75%; font-size: 14px;">
<div style="display: flex; align-items: center; margin-bottom: 10px; justify-content: flex-start;">
<img src="{self.avatar_url}" class="bot-avatar" alt="avatar" style="width: 30px; height: 30px;" />
<div style="background: #71797E; color: white; border-radius: 20px; padding: 10px; margin-right: 5px; margin-left: 5px; max-width: 75%; font-size: 14px;">
{formatted_text} \n </div>
</div>
"""
@@ -129,14 +145,13 @@ def display_dataframe(self, df):
"""
Display the dataframe in Streamlit UI within the chat container.
"""
avatar_url = "https://avataaars.io/?avatarStyle=Transparent&topType=WinterHat2&accessoriesType=Kurt&hatColor=Blue01&facialHairType=MoustacheMagnum&facialHairColor=Blonde&clotheType=Overall&clotheColor=Gray01&eyeType=WinkWacky&eyebrowType=SadConcernedNatural&mouthType=Sad&skinColor=Light"
message_alignment = "flex-start"
avatar_class = "bot-avatar"

st.write(
f"""
<div style="display: flex; align-items: center; margin-bottom: 10px; justify-content: {message_alignment};">
<img src="{avatar_url}" class="{avatar_class}" alt="avatar" style="width: 50px; height: 50px;" />
<img src="{self.avatar_url}" class="{avatar_class}" alt="avatar" style="width: 30px; height: 30px;" />
</div>
""",
unsafe_allow_html=True,
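Note: with these changes both the chat message renderer and the streaming callback derive their avatar from the selected model label via get_model_url. A minimal usage sketch (not part of the commit; assumes utils is importable as a package and a running Streamlit app context):

from utils.snowchat_ui import StreamlitUICallbackHandler, message_func

model = "Mistral 8x7B - Groq"                 # label as selected by the radio in main.py
message_func("Hi! How can I help?", is_user=False, is_df=False, model=model)

handler = StreamlitUICallbackHandler(model)   # avatar resolved via get_model_url(model)
handler.start_loading_message()               # renders the "Thinking..." bubble with that avatar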
