Skip to content

Commit

Permalink
Add new env variable to define desired GPT model
Browse files: browse the repository at this point in the history
  • Loading branch information
zumpious committed Jan 18, 2025
1 parent a7e2c88 commit a656f00
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 6 deletions.
3 changes: 2 additions & 1 deletion .env.template
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
OPENAI_API_KEY=
PDF_PATH=
-VECTOR_DB_PATH=
+VECTOR_DB_PATH=
+GPT_MODEL=
2 changes: 1 addition & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def init_session_state() -> None:
st.session_state.messages = []
if "qa_chain" not in st.session_state:
env_vars = get_env_vars()
-st.session_state.qa_chain = setup_rag(env_vars['vector_db_path'])
+st.session_state.qa_chain = setup_rag(env_vars['vector_db_path'], env_vars['gpt_model'])

def process_query(query: str, chat_history: List[Dict[str, str]], k_value: int, fetch_k: int) -> Dict[str, Any]:
"""Process user query using RAG system with debug information.
Expand Down
4 changes: 2 additions & 2 deletions src/rag/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from langchain_community.vectorstores import FAISS
from langchain.chains import ConversationalRetrievalChain

-def setup_rag(vector_store_path: str):
+def setup_rag(vector_store_path: str, model_name: str = "gpt-4o-mini") -> ConversationalRetrievalChain:
"""Set up Retrieval-Augmented Generation (RAG) system.
Loads a FAISS vector store from disk, initializes OpenAI embeddings,
Expand All @@ -30,7 +30,7 @@ def setup_rag(vector_store_path: str):
allow_dangerous_deserialization=True
)

-llm = ChatOpenAI(model_name="gpt-4", temperature=0)
+llm = ChatOpenAI(model_name=model_name, temperature=0)
return ConversationalRetrievalChain.from_llm(
llm=llm,
retriever=vectorstore.as_retriever(),
Expand Down
6 changes: 4 additions & 2 deletions src/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,19 @@

load_dotenv()

def get_env_vars():
def get_env_vars() -> dict:
"""Get environment variables needed for RAG system.
Returns:
dict: Environment variables including:
- pdf_path: Path to source PDF document
- vector_db_path: Path to FAISS vector store
- openai_key: OpenAI API key
- gpt_model: OpenAI model name to use
"""
return {
'pdf_path': os.getenv('PDF_PATH'),
'vector_db_path': os.getenv('VECTOR_DB_PATH'),
'openai_key': os.getenv('OPENAI_API_KEY')
'openai_key': os.getenv('OPENAI_API_KEY'),
'gpt_model': os.getenv('GPT_MODEL', 'gpt-4o-mini')
}

0 comments on commit a656f00

Please sign in to comment.