-
Notifications
You must be signed in to change notification settings - Fork 13
/
create_index.py
56 lines (43 loc) · 1.94 KB
/
create_index.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
from langchain_community.document_loaders import TextLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings
from langchain_chroma import Chroma
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
from dotenv import load_dotenv
# Load environment variables from a .env file. This has to run before the
# Google GenAI embedding/chat clients are constructed further down
# (presumably they read their API key from the environment -- TODO confirm
# which variable name the deployment expects).
load_dotenv()
# Prompt used by the RAG chain. It exposes two placeholders, {question} and
# {context}. The text starts and ends with a newline, hence the empty first
# and last entries in the list below.
RAG_SEARCH_PROMPT_TEMPLATE = "\n".join(
    [
        "",
        "Using the following pieces of retrieved context, answer the question comprehensively and concisely.",
        "Ensure your response fully addresses the question based on the given context.",
        "**IMPORTANT:**",
        "Just provide the answer and never mention or refer to having access to the external context or information in your answer.",
        "If you are unable to determine the answer from the provided context, state 'I don't know.'",
        "Question: {question}",
        "Context: {context}",
        "",
    ]
)
# --- Build the persisted vector index ---------------------------------------
# Load the source text, split it into overlapping chunks, embed the chunks and
# store them in a Chroma collection persisted under ./db.
print("Loading & Chunking Docs...")
loader = TextLoader("./data/agency.txt")
docs = loader.load()
# 300-character chunks with a 50-character overlap between neighbours.
doc_splitter = RecursiveCharacterTextSplitter(chunk_size=300, chunk_overlap=50)
doc_chunks = doc_splitter.split_documents(docs)

print("Creating vector embeddings...")
embeddings = GoogleGenerativeAIEmbeddings(model="models/text-embedding-004")
# Give every chunk a stable, position-based id. Without explicit ids the store
# generates fresh random ids on each run, so re-running this script against
# the same persist directory would append another full copy of the chunks and
# retrieval would start returning duplicates. With stable ids a re-run
# overwrites the existing entries instead.
# NOTE(review): entries written by earlier runs under random ids, or chunks
# beyond the new chunk count after the source file shrinks, are not removed --
# delete the "db" directory once if a clean rebuild is needed.
chunk_ids = [f"agency-{position}" for position in range(len(doc_chunks))]
vectorstore = Chroma.from_documents(
    doc_chunks,
    embeddings,
    ids=chunk_ids,
    persist_directory="db",
)
# Semantic vector search: fetch the 3 chunks most similar to the query.
vectorstore_retriever = vectorstore.as_retriever(search_kwargs={"k": 3})


def _format_docs(retrieved_docs):
    """Join the text of the retrieved documents into one context string.

    The retriever yields Document objects; interpolating that list directly
    into the prompt would insert its Python repr (metadata, escaped newlines)
    rather than the chunk text itself.

    Args:
        retrieved_docs: Iterable of objects exposing a ``page_content`` string.

    Returns:
        The chunk texts separated by blank lines ("" when nothing was found).
    """
    return "\n\n".join(doc.page_content for doc in retrieved_docs)


# Test RAG chain
print("Test RAG chain...")
prompt = ChatPromptTemplate.from_template(RAG_SEARCH_PROMPT_TEMPLATE)
llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash", temperature=0.1)
# The incoming query string is fanned out: it is passed through unchanged as
# "question", and also sent to the retriever whose results are flattened to
# plain text for "context". The filled prompt then goes to the chat model and
# the reply is reduced to a plain string.
rag_chain = (
    {
        "context": vectorstore_retriever | _format_docs,
        "question": RunnablePassthrough(),
    }
    | prompt
    | llm
    | StrOutputParser()
)
query = "What are your pricing options?"
result = rag_chain.invoke(query)
print(f"Question: {query}")
print(f"Answer: {result}")