Skip to content

Commit

Permalink
chainlit新增chatflow agent支持 (#663)
Browse files Browse the repository at this point in the history
  • Loading branch information
userpj authored Dec 13, 2024
1 parent 16cd3e6 commit b06664f
Showing 1 changed file with 29 additions and 4 deletions.
33 changes: 29 additions & 4 deletions python/core/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from appbuilder.core.component import Component
from appbuilder.core.message import Message
from appbuilder.utils.logger_util import logger
from appbuilder.core.console.appbuilder_client.data_class import ToolChoiceFunction, ToolChoice
from appbuilder.core.console.appbuilder_client.data_class import ToolChoiceFunction, ToolChoice, Action

# 流式场景首包超时时,最大重试次数
MAX_RETRY_COUNT = 3
Expand Down Expand Up @@ -197,7 +197,7 @@ def run(self, message: Message, stream: bool=False):
conn.close()
"""

component: Component
user_session_config: Optional[Union[Any, str]] = None
user_session: Optional[Any] = None
Expand Down Expand Up @@ -556,6 +556,7 @@ def chainlit_agent(self, host='0.0.0.0', port=8091):
self.prepare_chainlit_readme()

conversation_ids = []
interrupt_dict = {}

def _chat(message: cl.Message):
if len(conversation_ids) == 0:
Expand All @@ -566,15 +567,39 @@ def _chat(message: cl.Message):
file_id = self.component.upload_local_file(
conversation_id, message.elements[0].path)
file_ids.append(file_id)
return self.component.run(conversation_id=conversation_id, query=message.content, file_ids=file_ids,
stream=True, tool_choice=self.tool_choice)

interrupt_ids = interrupt_dict.get(conversation_id, [])
interrupt_event_id = interrupt_ids.pop() if len(interrupt_ids) > 0 else None
action = None
if interrupt_event_id is not None:
action = Action.create_resume_action(interrupt_event_id)

tmp_message = self.component.run(conversation_id=conversation_id, query=message.content, file_ids=file_ids,
stream=True, tool_choice=self.tool_choice, action=action)
res_message=list(tmp_message.content)

interrupt_event_id = None
for ans in res_message:
for event in ans.events:
if event.content_type == "chatflow_interrupt":
interrupt_event_id = event.detail.get("interrupt_event_id")
if event.content_type == "publish_message" and event.event_type == "chatflow":
answer = event.detail.get("message")
ans.answer += answer

if interrupt_event_id is not None:
interrupt_ids.append(interrupt_event_id)
interrupt_dict[conversation_id] = interrupt_ids
tmp_message.content = res_message
return tmp_message

@cl.on_chat_start
async def start():
session_id = cl.user_session.get("id")
request_id = str(uuid.uuid4())
init_context(session_id=session_id, request_id=request_id)
conversation_ids.append(self.component.create_conversation())
interrupt_dict[conversation_ids[-1]] = []

@cl.on_message # this function will be called every time a user inputs a message in the UI
async def main(message: cl.Message):
Expand Down

0 comments on commit b06664f

Please sign in to comment.