diff --git a/chain.py b/chain.py
index d896c19..5a55bcb 100644
--- a/chain.py
+++ b/chain.py
@@ -4,10 +4,9 @@ import streamlit as st
 from langchain.chains import ConversationalRetrievalChain, LLMChain
 from langchain.chains.question_answering import load_qa_chain
-from langchain.chat_models import ChatOpenAI
+from langchain.chat_models import ChatOpenAI, BedrockChat
 from langchain.embeddings.openai import OpenAIEmbeddings
 from langchain.llms import OpenAI
-from langchain.llms.bedrock import Bedrock
 from langchain.vectorstores import SupabaseVectorStore
 from pydantic import BaseModel, validator
 from supabase.client import Client, create_client
@@ -36,8 +35,10 @@ def __init__(self, config: ModelConfig):
         self.model_type = config.model_type
         self.secrets = config.secrets
         self.callback_handler = config.callback_handler
-        account_tag = self.secrets['CF_ACCOUNT_TAG']
-        self.gateway_url = f"https://gateway.ai.cloudflare.com/v1/{account_tag}/k-1-gpt/openai"
+        account_tag = self.secrets["CF_ACCOUNT_TAG"]
+        self.gateway_url = (
+            f"https://gateway.ai.cloudflare.com/v1/{account_tag}/k-1-gpt/openai"
+        )
         self.setup()

     def setup(self):
@@ -54,7 +55,7 @@ def setup_gpt(self):
             api_key=self.secrets["OPENAI_API_KEY"],
             model_name="gpt-3.5-turbo-16k",
             max_tokens=500,
-            base_url=self.gateway_url
+            base_url=self.gateway_url,
         )

         self.llm = ChatOpenAI(
@@ -64,7 +65,7 @@ def setup_gpt(self):
             max_tokens=500,
             callbacks=[self.callback_handler],
             streaming=True,
-            base_url=self.gateway_url
+            base_url=self.gateway_url,
         )

     def setup_mixtral(self):
@@ -99,10 +100,11 @@ def setup_claude(self):
             "temperature": 0,
             "top_p": 0.9,
         }
-        self.q_llm = Bedrock(
+        self.q_llm = BedrockChat(
             model_id="anthropic.claude-instant-v1", client=bedrock_runtime
         )
-        self.llm = Bedrock(
+
+        self.llm = BedrockChat(
             model_id="anthropic.claude-instant-v1",
             client=bedrock_runtime,
             callbacks=[self.callback_handler],