diff --git a/chain.py b/chain.py index 2bff972..d66fc72 100644 --- a/chain.py +++ b/chain.py @@ -1,7 +1,7 @@ from typing import Any, Callable, Dict, Optional import streamlit as st -from langchain.chat_models import ChatOpenAI +from langchain_community.chat_models import ChatOpenAI from langchain.embeddings.openai import OpenAIEmbeddings from langchain.llms import OpenAI from langchain.vectorstores import SupabaseVectorStore @@ -33,7 +33,7 @@ class ModelConfig(BaseModel): @validator("model_type", pre=True, always=True) def validate_model_type(cls, v): - if v not in ["gpt", "mixtral8x22b", "claude", "mixtral8x7b"]: + if v not in ["gpt", "gemini", "claude", "mixtral8x7b"]: raise ValueError(f"Unsupported model type: {v}") return v @@ -56,8 +56,8 @@ def setup(self): self.setup_claude() elif self.model_type == "mixtral8x7b": self.setup_mixtral_8x7b() - elif self.model_type == "mixtral8x22b": - self.setup_mixtral_8x22b() + elif self.model_type == "gemini": + self.setup_gemini() def setup_gpt(self): @@ -97,9 +97,9 @@ def setup_claude(self): }, ) - def setup_mixtral_8x22b(self): + def setup_gemini(self): self.llm = ChatOpenAI( - model_name="mistralai/mixtral-8x22b", + model_name="google/gemini-pro-1.5", temperature=0.1, api_key=self.secrets["OPENROUTER_API_KEY"], max_tokens=700, @@ -155,8 +155,8 @@ def load_chain(model_name="GPT-3.5", callback_handler=None): model_type = "mixtral8x7b" elif "claude" in model_name.lower(): model_type = "claude" - elif "mixtral 8x22b" in model_name.lower(): - model_type = "mixtral8x22b" + elif "gemini" in model_name.lower(): + model_type = "gemini" else: raise ValueError(f"Unsupported model name: {model_name}") diff --git a/main.py b/main.py index 644ddc0..c313821 100644 --- a/main.py +++ b/main.py @@ -34,7 +34,7 @@ st.caption("Talk your way through data") model = st.radio( "", - options=["Claude-3 Haiku", "Mixtral 8x7B", "Mixtral 8x22B", "GPT-3.5"], + options=["Claude-3 Haiku", "Mixtral 8x7B", "Gemini 1.5 Pro", "GPT-3.5"], index=0, horizontal=True, ) diff --git a/requirements.txt b/requirements.txt index 8c73f89..4775929 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,13 +1,15 @@ -langchain==0.1.5 +langchain==0.1.15 pandas==1.5.0 pydantic==1.10.8 snowflake_snowpark_python==1.5.0 snowflake-snowpark-python[pandas] streamlit==1.31.0 -supabase==1.0.3 +supabase==2.4.1 unstructured==0.7.12 tiktoken==0.5.2 -openai==1.11.0 +openai==1.17.0 black==23.3.0 boto3==1.28.57 -langchain_openai==0.0.5 \ No newline at end of file +langchain_openai==0.1.2 +langchain-community==0.0.32 +langchain-core==0.1.41 \ No newline at end of file diff --git a/utils/snowchat_ui.py b/utils/snowchat_ui.py index 2b0ccbc..65dce2e 100644 --- a/utils/snowchat_ui.py +++ b/utils/snowchat_ui.py @@ -119,6 +119,7 @@ def start_loading_message(self): self.placeholder.markdown(loading_message_content, unsafe_allow_html=True) def on_llm_new_token(self, token, run_id, parent_run_id=None, **kwargs): + print("on llm bnew token ",token) if not self.has_streaming_started: self.has_streaming_started = True