Skip to content

Commit

Permalink
Add qwen 2 72B
Browse files Browse the repository at this point in the history
  • Loading branch information
kaarthik108 committed Jun 7, 2024
1 parent 8392eed commit 6405509
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 89 deletions.
132 changes: 47 additions & 85 deletions chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ class ModelConfig(BaseModel):

@validator("model_type", pre=True, always=True)
def validate_model_type(cls, v):
if v not in ["gpt", "llama", "claude", "mixtral8x7b", "arctic"]:
valid_model_types = ["qwen", "llama", "claude", "mixtral8x7b", "arctic"]
if v not in valid_model_types:
raise ValueError(f"Unsupported model type: {v}")
return v

Expand All @@ -43,85 +44,47 @@ def __init__(self, config: ModelConfig):
self.model_type = config.model_type
self.secrets = config.secrets
self.callback_handler = config.callback_handler
account_tag = self.secrets["CF_ACCOUNT_TAG"]
self.gateway_url = (
f"https://gateway.ai.cloudflare.com/v1/{account_tag}/k-1-gpt/openai"
)
self.setup()

def setup(self):
if self.model_type == "gpt":
self.setup_gpt()
elif self.model_type == "claude":
self.setup_claude()
elif self.model_type == "mixtral8x7b":
self.setup_mixtral_8x7b()
elif self.model_type == "llama":
self.setup_llama()
elif self.model_type == "arctic":
self.setup_arctic()

def setup_gpt(self):
self.llm = ChatOpenAI(
model_name="gpt-3.5-turbo",
temperature=0.2,
api_key=self.secrets["OPENAI_API_KEY"],
max_tokens=1000,
callbacks=[self.callback_handler],
streaming=True,
# base_url=self.gateway_url,
)

def setup_mixtral_8x7b(self):
self.llm = ChatOpenAI(
model_name="mixtral-8x7b-32768",
temperature=0.2,
api_key=self.secrets["GROQ_API_KEY"],
max_tokens=3000,
callbacks=[self.callback_handler],
streaming=True,
base_url="https://api.groq.com/openai/v1",
)

def setup_claude(self):
self.llm = ChatOpenAI(
model_name="anthropic/claude-3-haiku",
temperature=0.1,
api_key=self.secrets["OPENROUTER_API_KEY"],
max_tokens=700,
callbacks=[self.callback_handler],
streaming=True,
base_url="https://openrouter.ai/api/v1",
default_headers={
"HTTP-Referer": "https://snowchat.streamlit.app/",
"X-Title": "Snowchat",
self.llm = self._setup_llm()

def _setup_llm(self):
model_config = {
"qwen": {
"model_name": "qwen/qwen-2-72b-instruct",
"api_key": self.secrets["OPENROUTER_API_KEY"],
"base_url": "https://openrouter.ai/api/v1",
},
)

def setup_llama(self):
self.llm = ChatOpenAI(
model_name="meta-llama/llama-3-70b-instruct",
temperature=0.1,
api_key=self.secrets["OPENROUTER_API_KEY"],
max_tokens=700,
callbacks=[self.callback_handler],
streaming=True,
base_url="https://openrouter.ai/api/v1",
default_headers={
"HTTP-Referer": "https://snowchat.streamlit.app/",
"X-Title": "Snowchat",
"claude": {
"model_name": "anthropic/claude-3-haiku",
"api_key": self.secrets["OPENROUTER_API_KEY"],
"base_url": "https://openrouter.ai/api/v1",
},
)
"mixtral8x7b": {
"model_name": "mixtral-8x7b-32768",
"api_key": self.secrets["GROQ_API_KEY"],
"base_url": "https://api.groq.com/openai/v1",
},
"llama": {
"model_name": "meta-llama/llama-3-70b-instruct",
"api_key": self.secrets["OPENROUTER_API_KEY"],
"base_url": "https://openrouter.ai/api/v1",
},
"arctic": {
"model_name": "snowflake/snowflake-arctic-instruct",
"api_key": self.secrets["OPENROUTER_API_KEY"],
"base_url": "https://openrouter.ai/api/v1",
},
}

def setup_arctic(self):
self.llm = ChatOpenAI(
model_name="snowflake/snowflake-arctic-instruct",
config = model_config[self.model_type]

return ChatOpenAI(
model_name=config["model_name"],
temperature=0.1,
api_key=self.secrets["OPENROUTER_API_KEY"],
api_key=config["api_key"],
max_tokens=700,
callbacks=[self.callback_handler],
streaming=True,
base_url="https://openrouter.ai/api/v1",
base_url=config["base_url"],
default_headers={
"HTTP-Referer": "https://snowchat.streamlit.app/",
"X-Title": "Snowchat",
Expand Down Expand Up @@ -154,7 +117,7 @@ def _combine_documents(
return conversational_qa_chain


def load_chain(model_name="GPT-3.5", callback_handler=None):
def load_chain(model_name="qwen", callback_handler=None):
embeddings = OpenAIEmbeddings(
openai_api_key=st.secrets["OPENAI_API_KEY"], model="text-embedding-ada-002"
)
Expand All @@ -165,17 +128,16 @@ def load_chain(model_name="GPT-3.5", callback_handler=None):
query_name="v_match_documents",
)

if "GPT-3.5" in model_name:
model_type = "gpt"
elif "mixtral 8x7b" in model_name.lower():
model_type = "mixtral8x7b"
elif "claude" in model_name.lower():
model_type = "claude"
elif "llama" in model_name.lower():
model_type = "llama"
elif "arctic" in model_name.lower():
model_type = "arctic"
else:
model_type_mapping = {
"qwen 2-72b": "qwen",
"mixtral 8x7b": "mixtral8x7b",
"claude-3 haiku": "claude",
"llama 3-70b": "llama",
"snowflake arctic": "arctic",
}

model_type = model_type_mapping.get(model_name.lower())
if model_type is None:
raise ValueError(f"Unsupported model name: {model_name}")

config = ModelConfig(
Expand Down
8 changes: 7 additions & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,13 @@
st.caption("Talk your way through data")
model = st.radio(
"",
options=["Claude-3 Haiku", "Mixtral 8x7B", "Llama 3-70B", "GPT-3.5", "Snowflake Arctic"],
options=[
"Claude-3 Haiku",
"Mixtral 8x7B",
"Llama 3-70B",
"Qwen 2-72B",
"Snowflake Arctic",
],
index=0,
horizontal=True,
)
Expand Down
11 changes: 8 additions & 3 deletions utils/snowchat_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@

image_url = f"{st.secrets['SUPABASE_STORAGE_URL']}/storage/v1/object/public/snowchat/"
gemini_url = image_url + "google-gemini-icon.png?t=2024-05-07T21%3A17%3A52.235Z"
mistral_url = image_url + "mistral-ai-icon-logo-B3319DCA6B-seeklogo.com.png?t=2024-05-07T21%3A18%3A22.737Z"
mistral_url = (
image_url
+ "mistral-ai-icon-logo-B3319DCA6B-seeklogo.com.png?t=2024-05-07T21%3A18%3A22.737Z"
)
openai_url = (
image_url
+ "png-transparent-openai-chatgpt-logo-thumbnail.png?t=2024-05-07T21%3A18%3A44.079Z"
Expand All @@ -16,10 +19,12 @@
claude_url = image_url + "Claude.png?t=2024-05-07T21%3A16%3A17.252Z"
meta_url = image_url + "meta-logo.webp?t=2024-05-07T21%3A18%3A12.286Z"
snow_url = image_url + "Snowflake_idCkdSg0B6_6.png?t=2024-05-07T21%3A24%3A02.597Z"
qwen_url = image_url + "qwen.png?t=2024-06-07T08%3A51%3A36.363Z"


def get_model_url(model_name):
if "gpt" in model_name.lower():
return openai_url
if "qwen" in model_name.lower():
return qwen_url
elif "claude" in model_name.lower():
return claude_url
elif "llama" in model_name.lower():
Expand Down

0 comments on commit 6405509

Please sign in to comment.