From 1e135a136f6e3d28c2201158683815fb3265aa0f Mon Sep 17 00:00:00 2001
From: sreyakumar <121137643+sreyakumar@users.noreply.github.com>
Date: Tue, 26 Nov 2024 14:55:26 -0800
Subject: [PATCH] added streaming on streamlit

---
 app.py                                        | 21 +++++----
 src/metadata_chatbot/agents/GAMER.py          | 43 +++++++++++++------
 src/metadata_chatbot/agents/async_workflow.py | 12 +++++-
 3 files changed, 54 insertions(+), 22 deletions(-)

diff --git a/app.py b/app.py
index 54f7b1d..cc8ad5f 100644
--- a/app.py
+++ b/app.py
@@ -1,20 +1,23 @@
 # Import the Streamlit library
 import streamlit as st
-import asyncio
-
 import sys
 import os
+
+import asyncio
+
 sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), 'src')))
 
-import metadata_chatbot.agents.docdb_retriever
-import metadata_chatbot.agents.agentic_graph
-from metadata_chatbot.agents.async_workflow import astream
+from metadata_chatbot.agents.GAMER import GAMER
+import uuid
+
+#run on terminal with streamlit run c:/Users/sreya.kumar/Documents/GitHub/metadata-chatbot/app.py [ARGUMENTS]
 
-#run on terminal with streamlit run [ARGUMENTS]
 async def main():
-# Write a simple message to the app's webpage
+    llm = GAMER()
+    unique_id = str(uuid.uuid4())
+
     message = st.chat_message("assistant")
     message.write("Hello!")
 
@@ -33,12 +36,12 @@ async def main():
             st.markdown(prompt)
         # Add user message to chat history
         st.session_state.messages.append({"role": "user", "content": prompt})
-        response = await llm.ainvoke(prompt)
 
+        #response = await llm.ainvoke(prompt)
         with st.chat_message("assistant"):
+            response = await llm.streamlit_astream(prompt, unique_id = unique_id)
             st.markdown(response)
-
         # Add assistant response to chat history
         st.session_state.messages.append({"role": "assistant", "content": response})
 
diff --git a/src/metadata_chatbot/agents/GAMER.py b/src/metadata_chatbot/agents/GAMER.py
index 5df3ea3..364ad25 100644
--- a/src/metadata_chatbot/agents/GAMER.py
+++ b/src/metadata_chatbot/agents/GAMER.py
@@ -10,6 +10,11 @@
 from metadata_chatbot.agents.workflow import app
 from langchain_core.messages import AIMessage, HumanMessage
 
+from streamlit.runtime.scriptrunner import add_script_run_ctx
+
+from typing import Optional, List, Any, AsyncIterator
+from langchain.callbacks.manager import AsyncCallbackManager, CallbackManagerForLLMRun
+import streamlit as st
 
@@ -49,13 +54,7 @@ async def _acall(
         """
         Asynchronous call.
         """
-        # unique_id = str(uuid.uuid4())
-        # config = {"configurable":{"thread_id": unique_id}}
-        # inputs = {"query" : query}
-        # answer = await async_app.ainvoke(inputs)
-        # return answer['generation']
         async def main(query):
-        #async def main():
             unique_id = str(uuid.uuid4())
             config = {"configurable":{"thread_id": unique_id}}
 
@@ -79,6 +78,7 @@ async def main(query):
     def _stream(
         self,
         query: str,
+        unique_id: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
@@ -92,10 +92,23 @@
             yield chunk
 
-
-    async def _astream(query):
-        async def main(query):
-        #async def main():
+    async def streamlit_astream(
+        self,
+        query: str,
+        unique_id: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> str:
+        """
+        Asynchronous call.
+        """
+
+        config = {"configurable":{"thread_id": unique_id}}
+        inputs = {
+            "messages": [HumanMessage(query)],
+        }
+        async def main(query: str):
             unique_id = str(uuid.uuid4())
             config = {"configurable":{"thread_id": unique_id}}
 
@@ -106,12 +119,18 @@ async def main(query):
             for key, value in output.items():
                 if key != "database_query":
                     yield value['messages'][0].content
+                else:
+                    yield value['generation']
 
+        curr = None
         generation = None
         async for result in main(query):
-            print(result)
+            if curr != None:
+                st.write(curr)
+            curr = generation
             generation = result
         return generation
+
 
     @property
@@ -126,7 +145,7 @@ def _llm_type(self) -> str:
         """Get the type of language model used by this chat model.
         Used for logging purposes only."""
         return "Claude 3 Sonnet"
 
-llm = GAMER()
+# llm = GAMER()
 
 # async def main():
 #     query = "Can you list all the procedures performed on the specimen, including their start and end dates? in SmartSPIM_662616_2023-03-06_17-47-13"
diff --git a/src/metadata_chatbot/agents/async_workflow.py b/src/metadata_chatbot/agents/async_workflow.py
index ab9ed6a..8499e77 100644
--- a/src/metadata_chatbot/agents/async_workflow.py
+++ b/src/metadata_chatbot/agents/async_workflow.py
@@ -12,6 +12,8 @@
 from metadata_chatbot.agents.react_agent import react_agent
 from metadata_chatbot.agents.agentic_graph import datasource_router, filter_generation_chain, doc_grader, rag_chain
 
+import streamlit as st
+
 # from docdb_retriever import DocDBRetriever
 # from react_agent import react_agent
 # from agentic_graph import datasource_router, filter_generation_chain, doc_grader, rag_chain
 
@@ -67,11 +69,19 @@ async def retrieve_DB_async(state: dict) -> dict:
     inputs = {"messages": [("user", query)]}
 
     try:
+        prev = None
+        next = None
         async for s in react_agent.astream(inputs, stream_mode="values"):
             message = s["messages"][-1]
             if message.content != query :
+                if prev != None:
+                    st.write(prev)
                 state['messages'] = state.get('messages', []) + [message]
-                print(message.content) # Yield the statement as it's added
+                prev = next
+                next = message.content
+
+                # st.write(message.content)
+                # print(message.content) # Yield the statement as it's added
 
         answer = state['messages'][-1].content
     except: