-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathapp.py
178 lines (142 loc) · 6.49 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
import streamlit as st
import asyncio
import uuid
# import sys
# import os
# sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from metadata_chatbot.agents.async_workflow import async_workflow
from metadata_chatbot.agents.react_agent import astream_input
from langchain_core.messages import HumanMessage, AIMessage
from langgraph.checkpoint.memory import MemorySaver
from streamlit_feedback import streamlit_feedback
import uuid
import warnings
# Ignore every Python warning for the whole process, so warnings raised by
# the imported libraries are not emitted while the app runs.
warnings.filterwarnings('ignore')
@st.cache_resource
def load_checkpointer():
    """Create the in-memory LangGraph checkpointer.

    The function is wrapped in ``st.cache_resource``, so the ``MemorySaver``
    it builds is cached by Streamlit and handed back on later calls instead
    of being constructed again.
    """
    checkpointer = MemorySaver()
    return checkpointer
async def stream_database_query(result):
    """Convert one streamed database-agent event into displayable output.

    ``result`` is a mapping with a ``'type'`` key and, for the recognised
    types, a ``'content'`` key.

    Returns:
        The content unchanged for ``'intermediate_steps'`` and
        ``'final_answer'``; a Markdown string wrapping the content for
        ``'agg_pipeline'`` and ``'tool_response'``; ``None`` for any other
        type (``'content'`` is not read in that case).
    """
    kind = result['type']

    # Both of these event types pass their content through untouched.
    if kind in {'intermediate_steps', 'final_answer'}:
        return result['content']

    if kind == 'agg_pipeline':
        return f"The MongoDB pipeline used to on the database is:\n`{result['content']}`"

    if kind == 'tool_response':
        return f"Retrieved output from MongoDB:\n```json\n{result['content']}\n```"

    return None
async def handle_database_query(query: str, chat_history: list):
    """Stream the database agent's output for ``query`` as display chunks.

    The chat history is stringified and prepended to the query, the combined
    text is passed to ``astream_input``, and every streamed event is
    converted with ``stream_database_query`` before being yielded.

    Any ``Exception`` raised while building the prompt or streaming ends the
    stream with a single error-message chunk instead of propagating.
    """
    try:
        prompt = str(chat_history) + query
        async for event in astream_input(query=prompt):
            rendered = await stream_database_query(event)
            yield rendered
    except Exception as err:
        yield f"An error has occured with the retrieval from DocDB: {err}. Try structuring your query another way."
async def yield_model_response(output_stream, query=None, chat_history=None):
    """Flatten a streamed workflow run into displayable chunks.

    Args:
        output_stream: Async iterable of dicts. Each dict maps a key to that
            step's output.
        query: Optional question to forward to ``handle_database_query``.
            When omitted, the ``"database_query"`` value itself is forwarded,
            which is the argument the previous implementation passed.
            NOTE(review): callers that know the user's question should pass
            it explicitly.
        chat_history: Optional chat history to forward to
            ``handle_database_query``; defaults to an empty list.

    Yields:
        For every key other than ``"database_query"``, the ``content`` of the
        first entry in that value's ``'messages'``. For ``"database_query"``,
        every chunk produced by ``handle_database_query``.
    """
    history = [] if chat_history is None else chat_history

    async for output in output_stream:
        for key, value in output.items():
            if key != "database_query":
                yield value['messages'][0].content
                continue

            # handle_database_query is an async generator that takes a query
            # and a chat history: it has to be iterated with ``async for``
            # (awaiting it raises TypeError) and given both arguments.
            db_query = value if query is None else query
            async for chunk in handle_database_query(db_query, history):
                yield chunk
async def answer_generation(query: str, chat_history: list, config: dict, model):
    """Stream the chatbot's answer to ``query`` chunk by chunk.

    Args:
        query: The user's question.
        chat_history: Messages passed to the model as ``{"messages": ...}``;
            its string form is also prepended to ``query`` for the database
            agent.
        config: Configuration forwarded unchanged to ``model.astream``.
        model: Object exposing ``astream(inputs, config)`` that yields dicts
            mapping a key to that step's output.

    Yields:
        For every key other than ``"database_query"``, the ``content`` of the
        first entry in that value's ``'messages'``. For ``"database_query"``,
        the events streamed by ``astream_input``: intermediate steps and the
        final answer as-is, the aggregation pipeline and the tool response
        each as a label followed by a Markdown-formatted chunk. If the
        database agent raises, a single error-message chunk is yielded.
    """
    inputs = {"messages": chat_history}

    async for output in model.astream(inputs, config):
        for key, value in output.items():
            if key != "database_query":
                yield value['messages'][0].content
                continue

            try:
                # Keep the prompt in its own local: re-binding ``query`` here
                # would prepend the history again on every further
                # "database_query" output of the same run.
                db_prompt = str(chat_history) + query
                async for result in astream_input(query=db_prompt):
                    response = result['type']
                    # An event has exactly one type, so the branches are
                    # mutually exclusive.
                    if response in ('intermediate_steps', 'final_answer'):
                        yield result['content']
                    elif response == 'agg_pipeline':
                        yield "The MongoDB pipeline used to on the database is:"
                        yield f"`{result['content']}`"
                    elif response == 'tool_response':
                        yield "Retrieved output from MongoDB:"
                        yield f"```json\n{result['content']}\n```"
            except Exception as e:
                # Surface retrieval failures to the user as a chat chunk
                # rather than crashing the page.
                yield f"An error has occured with the retrieval from DocDB: {e}. Try structuring your query another way."
def set_query(query):
    """Store ``query`` as the pending query in Streamlit's session state.

    Used as the ``on_click`` callback of the example-question buttons.
    """
    session = st.session_state
    session.query = query
async def main():
    """Render the GAMER chat page and answer the pending query, if any.

    On every run the page draws the title, the example-question buttons, the
    greeting, the chat input, and the stored conversation. If a query is
    pending in ``st.session_state.query`` (typed, or set by an example
    button), it is appended to the conversation, answered by streaming
    ``answer_generation``, and the final chunk is stored as the assistant's
    reply. The pending query is cleared before returning.
    """
    st.title("GAMER: Generative Analysis of Metadata Retrieval")

    # Create the thread id once per session. The same id is passed to the
    # compiled workflow on every run through ``config`` below.
    if "thread_id" not in st.session_state:
        st.session_state.thread_id = str(uuid.uuid4())

    checkpointer = load_checkpointer()
    model = async_workflow.compile(checkpointer=checkpointer)

    if 'query' not in st.session_state:
        st.session_state.query = ''

    st.info(
        "Ask a question about the AIND metadata! Please note that it will take a couple of seconds to generate an answer. Type a query to start or pick one of these suggestions:"
    )

    examples = [
        "What are the modalities that exist in the database? What are the least and most common ones?",
        "What is the MongoDB query to find the injections used in SmartSPIM_675387_2023-05-23_23-05-56",
        "Can you list all the procedures performed on 662616, including their start and end dates?"
    ]

    # One button per example; clicking it stores the example as the query.
    for column, example in zip(st.columns(len(examples)), examples):
        with column:
            st.button(example, on_click=set_query, args=[example])

    greeting = st.chat_message("assistant")
    greeting.write("Hello! How can I help you?")

    user_query = st.chat_input("Message GAMER")
    if user_query:
        st.session_state.query = user_query

    if "messages" not in st.session_state:
        st.session_state.messages = []

    # Replay the conversation stored so far.
    for past_message in st.session_state.messages:
        role = "user" if isinstance(past_message, HumanMessage) else "assistant"
        with st.chat_message(role):
            st.markdown(past_message.content)

    query = st.session_state.query
    if query:
        st.session_state.messages.append(HumanMessage(query))

        with st.chat_message("user"):
            st.markdown(query)

        with st.chat_message("assistant"):
            config = {"configurable": {"thread_id": st.session_state.thread_id}}
            prev = None
            generation = None
            chat_history = st.session_state.messages

            with st.status("Generating answer...", expanded=True) as status:
                async for result in answer_generation(query, chat_history, config, model):
                    # Show each chunk one step late, so the last chunk is
                    # kept out of the status box and rendered below as the
                    # final answer.
                    if prev is not None:
                        st.markdown(prev)
                    prev = result
                    generation = prev
                status.update(label="Answer generation successful.")

            # Nothing was streamed: fall back to a text reply so the message
            # below is never built from None.
            if generation is None:
                generation = "No answer was generated. Try rephrasing your question."

            final_response = st.empty()
            final_response.markdown(generation)

            feedback = streamlit_feedback(
                feedback_type="thumbs",
                optional_text_label="[Optional] Please provide an explanation for your choice",
            )
            # NOTE(review): feedback is only printed to stdout here; it is
            # not stored anywhere in this function.
            print(feedback)

        st.session_state.messages.append(AIMessage(generation))

    # Clear the pending query so it is not answered again on the next run.
    st.session_state.query = ''
# Entry point when the file is executed as a script: main() is a coroutine
# function, so it is run to completion with asyncio.run.
if __name__ == "__main__":
    asyncio.run(main())