# Importing required libraries and modules
import logging
import os

import streamlit as st
from langchain.callbacks.base import BaseCallbackHandler
from langchain.chains import RetrievalQAWithSourcesChain
from langchain.retrievers.web_research import WebResearchRetriever
def setup_logging():
    logging.basicConfig()
    logging.getLogger("langchain.retrievers.web_research").setLevel(logging.INFO)


setup_logging()
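
# The OpenAI and Google Search integrations below read their credentials from
# environment variables. A minimal sanity check, assuming the standard variable
# names read by OpenAIEmbeddings (OPENAI_API_KEY) and GoogleSearchAPIWrapper
# (GOOGLE_API_KEY, GOOGLE_CSE_ID):
def check_required_env_vars():
    required = ["OPENAI_API_KEY", "GOOGLE_API_KEY", "GOOGLE_CSE_ID"]
    missing = [name for name in required if not os.environ.get(name)]
    if missing:
        raise RuntimeError(
            "Missing required environment variables: " + ", ".join(missing)
        )


check_required_env_vars()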
# Configuring the Streamlit page appearance; st.set_page_config should be
# called before any other Streamlit command in the script
st.set_page_config(
    page_title="Technical Documentation AI Bot", page_icon="🤖", layout="wide"
)
def settings():
    # Importing necessary modules for creating vector stores and embeddings
    import faiss
    from langchain.vectorstores import FAISS
    from langchain.embeddings.openai import OpenAIEmbeddings
    from langchain.docstore import InMemoryDocstore

    # Initializing the OpenAI embeddings model
    embeddings_model = OpenAIEmbeddings()
    embedding_size = 1536  # Dimensionality of OpenAI's text-embedding-ada-002 vectors

    # Creating a FAISS index for storing embeddings
    index = faiss.IndexFlatL2(embedding_size)

    # Creating a FAISS vector store using the embeddings and index
    vectorstore_public = FAISS(
        embeddings_model.embed_query, index, InMemoryDocstore({}), {}
    )

    # Initializing the language model (GPT-4 in this case); streaming=True is
    # needed so StreamHandler.on_llm_new_token receives tokens as they are
    # generated
    from langchain.chat_models import ChatOpenAI

    llm = ChatOpenAI(model_name="gpt-4", streaming=True)

    # Setting up a Google Search API Wrapper for web retrieval
    from langchain.utilities import GoogleSearchAPIWrapper

    search = GoogleSearchAPIWrapper()

    # Initializing the Web Research Retriever with the necessary components
    web_retriever = WebResearchRetriever.from_llm(
        vectorstore=vectorstore_public, llm=llm, search=search, num_search_results=3
    )

    return web_retriever, llm
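
# Note: WebResearchRetriever uses the LLM to generate search queries for the
# question, fetches the top results through the Google Search wrapper, and
# indexes the fetched pages into the FAISS store before running a similarity
# search. The settings() call is cached in st.session_state further below, so
# this setup only runs once per session.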
class StreamHandler(BaseCallbackHandler):
    def __init__(self, container, initial_text=""):
        # Initializing the StreamHandler with a container to display text and initial text
        self.container = container
        self.text = initial_text

    def on_llm_new_token(self, token: str, **kwargs) -> None:
        # Updating the displayed text as new tokens are generated by the language model
        self.text += token
        self.container.markdown(self.text + "▌")
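
# Note: on_llm_new_token is only invoked when the chat model is constructed
# with streaming=True (as done in settings() above); without it, the answer
# would appear all at once instead of streaming token by token.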
class PrintRetrievalHandler(BaseCallbackHandler):
    def __init__(self, container):
        # Initializing the PrintRetrievalHandler with a container to display retrieved documents
        self.container = container.expander("Context Retrieval")

    def on_retriever_start(self, query: str, **kwargs):
        # Displaying the question/query when the retrieval starts
        self.container.write(f"**Question:** {query}")

    def on_retriever_end(self, documents, **kwargs):
        # Displaying the retrieved documents when the retrieval ends
        for doc in documents:
            # Fall back gracefully if a document is missing source metadata
            source = doc.metadata.get("source", "unknown source")
            self.container.write(f"**Results from {source}**")
            self.container.text(doc.page_content)
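
# Note: in practice, the "source" metadata displayed above is the URL of the
# page the WebResearchRetriever fetched; other retrievers may store it under a
# different key, hence the defensive .get() above.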
# Displaying the main header and information about the application
st.header("Technical Documentation AI Chat Bot")
st.info(
    "I can answer technical questions in real time by checking documentation for various tools and services."
)
# Displaying the list of supported documentation
st.subheader("Supported Documentation")
documentation_links = """
1. [**AWS Documentation**](https://docs.aws.amazon.com)
2. [**dbt Documentation**](https://getdbt.com)
3. [**dbt Project Evaluator**](https://dbt-labs.github.io/dbt-project-evaluator)
4. [**Fivetran Documentation**](https://fivetran.com/docs)
5. [**Looker Documentation**](https://cloud.google.com/looker/docs)
6. [**Prefect Documentation**](https://docs.prefect.io)
7. [**Python (Langchain) Documentation**](https://python.langchain.com/docs)
8. [**Snowflake Documentation**](https://docs.snowflake.com)
9. [**Streamlit Documentation**](https://docs.streamlit.io)
"""
st.markdown(documentation_links)
# Initializing the retriever and language model if they haven't been initialized yet
if "retriever" not in st.session_state:
st.session_state["retriever"], st.session_state["llm"] = settings()
web_retriever = st.session_state.retriever
llm = st.session_state.llm
# Initializing a list to store messages if it doesn't exist yet
if "messages" not in st.session_state:
st.session_state.messages = []
# Displaying all the previous messages in the chat
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])
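
# Streamlit reruns this script from top to bottom on every interaction, so the
# chat history above is re-rendered on each run; the block below only handles
# the newest question.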
# Input field for the user to ask a question
if question := st.chat_input("Ask a question:"):
    try:
        # Storing the user's question and displaying it in the chat
        st.session_state.messages.append({"role": "user", "content": question})
        with st.chat_message("user"):
            st.markdown(question)

        # Displaying a loading spinner while the question is processed
        with st.spinner("Processing your question..."):
            # Initializing the QA chain with the language model and the web retriever
            qa_chain = RetrievalQAWithSourcesChain.from_chain_type(
                llm, retriever=web_retriever
            )

            # Setting up callback handlers to manage the display of retrieval
            # results and the generated answer
            retrieval_streamer_cb = PrintRetrievalHandler(st.container())
            answer = st.empty()
            stream_handler = StreamHandler(answer, initial_text="`Answer:`\n\n")

            # Executing the QA chain to generate and display the answer
            result = qa_chain(
                {"question": question},
                callbacks=[retrieval_streamer_cb, stream_handler],
            )

            # Storing the full response and displaying it in the chat
            full_response = "`Answer:`\n\n" + result["answer"]
            st.session_state.messages.append(
                {"role": "assistant", "content": full_response}
            )
            with st.chat_message("assistant"):
                st.markdown(full_response)

            # Displaying the sources of the information provided in the answer
            st.info("`Sources:`\n\n" + result["sources"])
    except Exception as e:
        st.error(
            "Sorry, an error occurred while processing your question. Please try again later."
        )
        logging.error("Error processing question: %s", e)
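
# To run the app locally (a typical invocation, assuming Streamlit and the
# LangChain/FAISS dependencies are installed):
#   streamlit run technical-documentation-ai.py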