-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdocuments_qa.py
137 lines (105 loc) · 3.88 KB
/
documents_qa.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import os
from dotenv import load_dotenv
from langchain.document_loaders import PyPDFLoader, TextLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores import Chroma
from langchain.chains import RetrievalQAWithSourcesChain
from langchain.chat_models import ChatOpenAI
import chainlit as cl
from chainlit.types import AskFileResponse
from apikey import OPENAI_API_KEY
# Expose the key from the local `apikey` module through the environment before
# the OpenAI-backed objects below are constructed.
os.environ["OPENAI_API_KEY"] = OPENAI_API_KEY

# Module-level helpers shared by every chat session (used by process_file /
# get_docsearch): ~1000-character chunks with 200 characters of overlap.
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000,
    chunk_overlap=200,
)
embeddings = OpenAIEmbeddings()

# Prompt shown in the upload dialog at chat start.
welcome_message = """ Welcome to LLM Pdf QA Demo!
To start:
1. Upload Pdf or text file
2. Ask a question about file
"""
def process_file(file: AskFileResponse):
    """Load an uploaded file and split it into chunks tagged with source ids.

    Args:
        file: The uploaded file. ``file.type`` selects the loader and
            ``file.content`` holds the raw bytes.

    Returns:
        The list of chunks produced by ``text_splitter``. Each chunk's
        ``metadata["source"]`` is set to ``"source_<i>"`` (its position in the list).

    Raises:
        ValueError: If ``file.type`` is neither ``text/plain`` nor ``application/pdf``.
    """
    import tempfile

    if file.type == "text/plain":
        loader_cls = TextLoader
    elif file.type == "application/pdf":
        loader_cls = PyPDFLoader
    else:
        # Previously this fell through and failed later with UnboundLocalError.
        raise ValueError(f"Unsupported file type: {file.type!r}")

    # The loaders take a *path*, so the bytes must be on disk and flushed before
    # the loader reopens the file. Close the handle first: on Windows an open
    # NamedTemporaryFile cannot be reopened by name. Remove it ourselves afterwards.
    tmp = tempfile.NamedTemporaryFile(delete=False)
    try:
        with tmp:
            tmp.write(file.content)
        documents = loader_cls(tmp.name).load()
    finally:
        os.remove(tmp.name)

    docs = text_splitter.split_documents(documents)
    # Stable ids that the QA chain cites and main() maps back to chunk text.
    for i, doc in enumerate(docs):
        doc.metadata["source"] = f"source_{i}"
    return docs
def get_docsearch(file: AskFileResponse):
    """Chunk *file*, keep the chunks in the user session, and index them in Chroma.

    The chunks are stored under the ``"docs"`` session key so the message
    handler can map the source ids cited by the chain back to chunk text.

    Returns:
        A Chroma vector store holding this file's chunks in its own collection.
    """
    import uuid

    docs = process_file(file)
    # Save data in the user session
    cl.user_session.set("docs", docs)
    # Create a unique namespace (collection) for the file. Without an explicit
    # collection_name every upload goes into Chroma's default collection. If the
    # in-process Chroma client is shared between sessions, chunks from different
    # files would then mix, and their "source_N" ids would collide.
    # NOTE(review): whether the default client is shared depends on the chromadb
    # version — confirm.
    docsearch = Chroma.from_documents(
        docs,
        embeddings,
        collection_name=f"file_{uuid.uuid4().hex}",
    )
    return docsearch
@cl.on_chat_start
async def start():
    """Ask for a PDF/text file, index it, and store a QA chain in the user session.

    Side effects: the ``"docs"`` session key is set (via ``get_docsearch``), as
    is the ``"chain"`` key later read by the message handler.
    """
    # NOTE(review): the API key is already set from `apikey` at import time, so this
    # only matters for other variables in a .env file — confirm it is still needed.
    load_dotenv()
    # Greet the user before asking for a file.
    await cl.Message(content="Now chat with your Pdfs.").send()

    # Keep asking until a file is actually returned (None presumably means the
    # prompt timed out after `timeout` seconds).
    files = None
    while files is None:
        files = await cl.AskFileMessage(
            content=welcome_message,
            accept=["text/plain", "application/pdf"],
            max_size_mb=20,
            timeout=180,
        ).send()
    file = files[0]

    msg = cl.Message(content=f"Processing `{file.name}`...")
    await msg.send()

    # get_docsearch is synchronous; cl.make_async wraps it so it can be awaited.
    docsearch = await cl.make_async(get_docsearch)(file)

    chain = RetrievalQAWithSourcesChain.from_chain_type(
        ChatOpenAI(temperature=0, streaming=True),
        chain_type="stuff",
        retriever=docsearch.as_retriever(),
        # max_tokens_limit belongs to the chain, not the retriever (it previously
        # had no effect there). It is only enforced when reduce_k_below_max_tokens
        # is True, in which case retrieved docs are dropped until they fit.
        # NOTE(review): 4097 looks like a whole model context window; a smaller
        # budget may be needed to leave room for the prompt and answer.
        reduce_k_below_max_tokens=True,
        max_tokens_limit=4097,
    )

    # inform user that system is live
    msg.content = f"`{file.name}` processed. You may now ask questions!"
    await msg.update()

    cl.user_session.set("chain", chain)
@cl.on_message
async def main(message):
    """Answer *message* with the session's QA chain and attach the cited chunks.

    The chain result's ``"sources"`` string is split on commas and each id is
    looked up among the chunks stored under the ``"docs"`` session key. Matched
    chunks are attached as ``cl.Text`` elements, and a "Sources: ..." (or
    "No sources found") line is appended to the answer.
    """
    chain = cl.user_session.get("chain")
    cb = cl.AsyncLangchainCallbackHandler(
        stream_final_answer=True, answer_prefix_tokens=["FINAL", "ANSWER"]
    )
    # Stream from the first token instead of waiting for the answer prefix tokens.
    cb.answer_reached = True
    res = await chain.acall(message, callbacks=[cb])
    answer = res["answer"]
    sources = res["sources"].strip()
    source_elements = []

    # Get the documents from the user session and index them by source id once,
    # keeping the first chunk for any duplicated id (same as list.index).
    docs = cl.user_session.get("docs")
    docs_by_source = {}
    for doc in docs:
        docs_by_source.setdefault(doc.metadata["source"], doc)

    if sources:
        found_sources = []
        # Add the sources to the message
        for source in sources.split(","):
            # Periods are stripped so that e.g. a trailing "source_3." still matches.
            source_name = source.strip().replace(".", "")
            doc = docs_by_source.get(source_name)
            if doc is None:
                continue
            found_sources.append(source_name)
            source_elements.append(cl.Text(content=doc.page_content, name=source_name))

        if found_sources:
            answer += f"\nSources: {', '.join(found_sources)}"
        else:
            answer += "\nNo sources found"

    if cb.has_streamed_final_answer:
        # Replace the streamed text with the post-processed answer. Previously the
        # "Sources: ..." / "No sources found" line was dropped in this branch.
        cb.final_stream.content = answer
        cb.final_stream.elements = source_elements
        await cb.final_stream.update()
    else:
        await cl.Message(content=answer, elements=source_elements).send()