# 06-ragfusion.py
import dotenv
import bs4
from operator import itemgetter
from langchain_community.document_loaders import WebBaseLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings
from langchain.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain.load import dumps, loads
from langchain_core.output_parsers import StrOutputParser
def reciprocal_rank_fusion(results: list[list], k: int = 60) -> list[tuple]:
    """Fuse several ranked document lists into one ranking (Reciprocal Rank Fusion).

    Every document earns ``1 / (rank + k)`` for each list it appears in, where
    ``rank`` is its zero-based position in that list. The contributions are
    summed per document.

    Args:
        results: Ranked lists of documents, one list per query. Each document
            is serialized with ``dumps`` and restored with ``loads``.
        k: Constant added to the rank in the RRF formula.

    Returns:
        A list of ``(document, fused_score)`` tuples sorted by score in
        descending order. Documents with equal scores keep the order in which
        they were first encountered. Empty input yields an empty list.

    Raises:
        ZeroDivisionError: If ``rank + k`` is zero for some document.
    """
    # Keyed by the serialized document so that the same document returned for
    # different queries accumulates into a single entry.
    fused_scores: dict[str, float] = {}
    for docs in results:
        for rank, doc in enumerate(docs):
            doc_json = dumps(doc)
            fused_scores[doc_json] = fused_scores.get(doc_json, 0.0) + 1 / (rank + k)

    # sorted() is stable even with reverse=True, so ties stay in insertion order.
    ranked = sorted(fused_scores.items(), key=itemgetter(1), reverse=True)
    return [(loads(doc_json), score) for doc_json, score in ranked]
if __name__ == "__main__":
    # Load environment variables from a .env file before any client is built.
    dotenv.load_dotenv()

    # --- Indexing -----------------------------------------------------------
    # Fetch the blog post, parsing only the elements carrying these CSS classes.
    loader = WebBaseLoader(
        web_paths=(
            "https://lilianweng.github.io/posts/2023-06-23-agent/",
        ),
        bs_kwargs=dict(
            parse_only=bs4.SoupStrainer(
                class_=("post-content", "post-title", "post-header")
            )
        ),
    )
    blog_docs = loader.load()

    # Split the post into overlapping chunks sized by tiktoken token count.
    text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
        chunk_size=300,
        chunk_overlap=50,
    )
    splits = text_splitter.split_documents(blog_docs)

    vectorstore = Chroma.from_documents(
        documents=splits,
        embedding=GoogleGenerativeAIEmbeddings(model="models/text-embedding-004"),
    )
    retriever = vectorstore.as_retriever(
        search_kwargs=dict(
            k=10,  # fetch only the 10 nearest chunks per query (default is 4)
        )
    )

    question = "What is task decomposition for LLM agents?"

    # One model configuration serves both query generation and answering.
    llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash", temperature=0)

    # --- RAG-Fusion: query generation + fused retrieval ---------------------
    query_gen_template = """You are a helpful assistant that generates multiple search queries based on a single input query. \n
Generate multiple search queries related to: {question} \n
Output (10 queries):"""
    prompt_rag_fusion = ChatPromptTemplate.from_template(query_gen_template)

    generate_queries = (
        prompt_rag_fusion
        | llm
        | StrOutputParser()
        # NOTE(review): assumes the model separates queries with a blank line;
        # confirm against real model output.
        | (lambda x: x.split("\n\n"))
    )

    # retriever.map() runs the retriever once per generated query; the
    # per-query result lists are then merged by reciprocal_rank_fusion.
    retrieval_chain_rag_fusion = generate_queries | retriever.map() | reciprocal_rank_fusion

    # --- RAG: answer the question from the fused context --------------------
    answer_template = """Answer the following question based on this context:
{context}
Question: {question}
"""
    prompt = ChatPromptTemplate.from_template(answer_template)

    # The retrieval chain is executed here, as part of the final chain, so the
    # query-generation and embedding calls happen exactly once per run.
    final_rag_chain = (
        dict(
            context=retrieval_chain_rag_fusion,
            question=itemgetter("question"),
        )
        | prompt
        | llm
        | StrOutputParser()
    )

    print(final_rag_chain.invoke({"question": question}))