Skip to content

Commit

Permalink
add llama 3
Browse files Browse the repository at this point in the history
  • Loading branch information
kaarthik108 committed Apr 18, 2024
1 parent 019034a commit 75cef85
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 23 deletions.
15 changes: 7 additions & 8 deletions chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class ModelConfig(BaseModel):

@validator("model_type", pre=True, always=True)
def validate_model_type(cls, v):
    """Reject any model family this app cannot configure.

    Runs before field coercion (pre=True) and even when the field is
    omitted (always=True), per the pydantic v1 validator contract.
    Returns the value unchanged when it is a supported family.
    Raises ValueError otherwise, which pydantic surfaces as a
    ValidationError to the caller.
    """
    # NOTE: the scraped diff showed both the old ("gemini") and new
    # ("llama") membership lines; only the post-change list is kept here.
    if v not in ["gpt", "llama", "claude", "mixtral8x7b"]:
        raise ValueError(f"Unsupported model type: {v}")
    return v

Expand All @@ -56,9 +56,8 @@ def setup(self):
self.setup_claude()
elif self.model_type == "mixtral8x7b":
self.setup_mixtral_8x7b()
elif self.model_type == "gemini":
self.setup_gemini()

elif self.model_type == "llama":
self.setup_llama()

def setup_gpt(self):
self.llm = ChatOpenAI(
Expand Down Expand Up @@ -97,9 +96,9 @@ def setup_claude(self):
},
)

def setup_gemini(self):
def setup_llama(self):
self.llm = ChatOpenAI(
model_name="google/gemini-pro-1.5",
model_name="meta-llama/llama-3-70b-instruct",
temperature=0.1,
api_key=self.secrets["OPENROUTER_API_KEY"],
max_tokens=700,
Expand Down Expand Up @@ -155,8 +154,8 @@ def load_chain(model_name="GPT-3.5", callback_handler=None):
model_type = "mixtral8x7b"
elif "claude" in model_name.lower():
model_type = "claude"
elif "gemini" in model_name.lower():
model_type = "gemini"
elif "llama" in model_name.lower():
model_type = "llama"
else:
raise ValueError(f"Unsupported model name: {model_name}")

Expand Down
27 changes: 15 additions & 12 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
st.caption("Talk your way through data")
model = st.radio(
"",
options=["Claude-3 Haiku", "Mixtral 8x7B", "Gemini 1.5 Pro", "GPT-3.5"],
options=["Claude-3 Haiku", "Mixtral 8x7B", "Llama 3-70B", "GPT-3.5"],
index=0,
horizontal=True,
)
Expand All @@ -52,9 +52,9 @@
st.session_state["toast_shown"] = True

# Show a warning if the model is rate-limited
if st.session_state['rate-limit']:
if st.session_state["rate-limit"]:
st.toast("Probably rate limited.. Go easy folks", icon="⚠️")
st.session_state['rate-limit'] = False
st.session_state["rate-limit"] = False

if st.session_state["model"] == "Mixtral 8x7B":
st.warning("This is highly rate-limited. Please use it sparingly", icon="⚠️")
Expand Down Expand Up @@ -181,12 +181,15 @@ def execute_sql(query, conn, retries=2):
)
append_message(result.content)

if st.session_state["model"] == "Mixtral 8x7B" and st.session_state['messages'][-1]['content'] == "":
st.session_state['rate-limit'] = True

# if get_sql(result):
# conn = SnowflakeConnection().get_session()
# df = execute_sql(get_sql(result), conn)
# if df is not None:
# callback_handler.display_dataframe(df)
# append_message(df, "data", True)
if (
st.session_state["model"] == "Mixtral 8x7B"
and st.session_state["messages"][-1]["content"] == ""
):
st.session_state["rate-limit"] = True

# if get_sql(result):
# conn = SnowflakeConnection().get_session()
# df = execute_sql(get_sql(result), conn)
# if df is not None:
# callback_handler.display_dataframe(df)
# append_message(df, "data", True)
8 changes: 5 additions & 3 deletions utils/snowchat_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,16 @@
)
user_url = image_url + "cat-with-sunglasses.png"
claude_url = image_url + "Claude.png?t=2024-03-13T23%3A47%3A16.824Z"
meta_url = image_url + "meta-logo.webp?t=2024-04-18T22%3A43%3A17.775Z"


def get_model_url(model_name):
    """Return the provider logo URL matching *model_name*.

    Matching is case-insensitive substring lookup; unrecognised names
    fall back to the Mistral logo.
    """
    lowered = model_name.lower()
    if "gpt" in lowered:
        return openai_url
    if "claude" in lowered:
        return claude_url
    if "mixtral" in lowered:
        return mistral_url
    if "llama" in lowered:
        return meta_url
    if "gemini" in lowered:
        return gemini_url
    # Default avatar when no known provider keyword matches.
    return mistral_url
Expand Down Expand Up @@ -119,7 +121,7 @@ def start_loading_message(self):
self.placeholder.markdown(loading_message_content, unsafe_allow_html=True)

def on_llm_new_token(self, token, run_id, parent_run_id=None, **kwargs):
print("on llm bnew token ",token)
print("on llm bnew token ", token)
if not self.has_streaming_started:
self.has_streaming_started = True

Expand Down

0 comments on commit 75cef85

Please sign in to comment.