From 75cef85f419cd8f9536a9276a9b295e15aa47ed1 Mon Sep 17 00:00:00 2001
From: kaarthik108
Date: Fri, 19 Apr 2024 10:45:24 +1200
Subject: [PATCH] Add llama 3

---
 chain.py             | 15 +++++++--------
 main.py              | 27 +++++++++++++++------------
 utils/snowchat_ui.py |  8 +++++---
 3 files changed, 27 insertions(+), 23 deletions(-)

diff --git a/chain.py b/chain.py
index d66fc72..178c380 100644
--- a/chain.py
+++ b/chain.py
@@ -33,7 +33,7 @@ class ModelConfig(BaseModel):
 
     @validator("model_type", pre=True, always=True)
     def validate_model_type(cls, v):
-        if v not in ["gpt", "gemini", "claude", "mixtral8x7b"]:
+        if v not in ["gpt", "llama", "claude", "mixtral8x7b"]:
             raise ValueError(f"Unsupported model type: {v}")
         return v
 
@@ -56,9 +56,8 @@ def setup(self):
             self.setup_claude()
         elif self.model_type == "mixtral8x7b":
             self.setup_mixtral_8x7b()
-        elif self.model_type == "gemini":
-            self.setup_gemini()
-
+        elif self.model_type == "llama":
+            self.setup_llama()
 
     def setup_gpt(self):
         self.llm = ChatOpenAI(
@@ -97,9 +96,9 @@ def setup_claude(self):
             },
         )
 
-    def setup_gemini(self):
+    def setup_llama(self):
         self.llm = ChatOpenAI(
-            model_name="google/gemini-pro-1.5",
+            model_name="meta-llama/llama-3-70b-instruct",
             temperature=0.1,
             api_key=self.secrets["OPENROUTER_API_KEY"],
             max_tokens=700,
@@ -155,8 +154,8 @@ def load_chain(model_name="GPT-3.5", callback_handler=None):
         model_type = "mixtral8x7b"
     elif "claude" in model_name.lower():
         model_type = "claude"
-    elif "gemini" in model_name.lower():
-        model_type = "gemini"
+    elif "llama" in model_name.lower():
+        model_type = "llama"
     else:
         raise ValueError(f"Unsupported model name: {model_name}")
 
diff --git a/main.py b/main.py
index c313821..b527704 100644
--- a/main.py
+++ b/main.py
@@ -34,7 +34,7 @@
 st.caption("Talk your way through data")
 model = st.radio(
     "",
-    options=["Claude-3 Haiku", "Mixtral 8x7B", "Gemini 1.5 Pro", "GPT-3.5"],
+    options=["Claude-3 Haiku", "Mixtral 8x7B", "Llama 3-70B", "GPT-3.5"],
     index=0,
     horizontal=True,
 )
@@ -52,9 +52,9 @@
     st.session_state["toast_shown"] = True
 
 # Show a warning if the model is rate-limited
-if st.session_state['rate-limit']:
+if st.session_state["rate-limit"]:
     st.toast("Probably rate limited.. Go easy folks", icon="⚠️")
-    st.session_state['rate-limit'] = False
+    st.session_state["rate-limit"] = False
 
 if st.session_state["model"] == "Mixtral 8x7B":
     st.warning("This is highly rate-limited. Please use it sparingly", icon="⚠️")
@@ -181,12 +181,15 @@ def execute_sql(query, conn, retries=2):
         )
         append_message(result.content)
 
-if st.session_state["model"] == "Mixtral 8x7B" and st.session_state['messages'][-1]['content'] == "":
-    st.session_state['rate-limit'] = True
-
-    # if get_sql(result):
-    #     conn = SnowflakeConnection().get_session()
-    #     df = execute_sql(get_sql(result), conn)
-    #     if df is not None:
-    #         callback_handler.display_dataframe(df)
-    #         append_message(df, "data", True)
+if (
+    st.session_state["model"] == "Mixtral 8x7B"
+    and st.session_state["messages"][-1]["content"] == ""
+):
+    st.session_state["rate-limit"] = True
+
+    # if get_sql(result):
+    #     conn = SnowflakeConnection().get_session()
+    #     df = execute_sql(get_sql(result), conn)
+    #     if df is not None:
+    #         callback_handler.display_dataframe(df)
+    #         append_message(df, "data", True)
diff --git a/utils/snowchat_ui.py b/utils/snowchat_ui.py
index 65dce2e..72aaff8 100644
--- a/utils/snowchat_ui.py
+++ b/utils/snowchat_ui.py
@@ -13,14 +13,16 @@
 )
 user_url = image_url + "cat-with-sunglasses.png"
 claude_url = image_url + "Claude.png?t=2024-03-13T23%3A47%3A16.824Z"
+meta_url = image_url + "meta-logo.webp?t=2024-04-18T22%3A43%3A17.775Z"
+
 
 def get_model_url(model_name):
     if "gpt" in model_name.lower():
         return openai_url
     elif "claude" in model_name.lower():
         return claude_url
-    elif "mixtral" in model_name.lower():
-        return mistral_url
+    elif "llama" in model_name.lower():
+        return meta_url
     elif "gemini" in model_name.lower():
         return gemini_url
     return mistral_url
@@ -119,7 +121,7 @@
     def start_loading_message(self):
         self.placeholder.markdown(loading_message_content, unsafe_allow_html=True)
 
     def on_llm_new_token(self, token, run_id, parent_run_id=None, **kwargs):
-        print("on llm bnew token ",token)
+        print("on llm new token ", token)
         if not self.has_streaming_started:
             self.has_streaming_started = True
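
Note on the new "llama" path: setup_llama reuses LangChain's OpenAI-compatible
ChatOpenAI client against OpenRouter, just as the removed setup_gemini did; only
the model_name changes. Below is a minimal standalone sketch of that wiring for
testing the model id outside the app. The langchain_openai import, the env-var
lookup, and the explicit base_url are assumptions for illustration; the hunk
above is truncated and does not show how the client is pointed at OpenRouter.

    # Sketch only; not part of the patch. Assumes the langchain_openai package
    # is installed and OPENROUTER_API_KEY is set in the environment.
    import os

    from langchain_openai import ChatOpenAI

    llm = ChatOpenAI(
        model_name="meta-llama/llama-3-70b-instruct",  # model id added above
        temperature=0.1,
        api_key=os.environ["OPENROUTER_API_KEY"],
        max_tokens=700,
        # Assumption: OpenRouter's OpenAI-compatible endpoint.
        base_url="https://openrouter.ai/api/v1",
    )

    # One-off smoke test: a single completion routed through OpenRouter.
    print(llm.invoke("Reply with the single word: ok").content)

If the key or model id is wrong, the call fails immediately, which is quicker
to debug here than through the Streamlit UI.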