From 3a94afa7a58f43e014f416c372c534f9b9bb293d Mon Sep 17 00:00:00 2001 From: jinno Date: Thu, 22 Aug 2024 08:13:35 +0900 Subject: [PATCH] fix: generate_response_stream accept BaseMessageContent --- src/codeinterpreterapi/brain/brain.py | 15 ++-- src/codeinterpreterapi/crew/crew_agent.py | 23 ++++-- src/codeinterpreterapi/session.py | 94 +++++++++++++++-------- 3 files changed, 87 insertions(+), 45 deletions(-) diff --git a/src/codeinterpreterapi/brain/brain.py b/src/codeinterpreterapi/brain/brain.py index 8a0d40bb..72482326 100644 --- a/src/codeinterpreterapi/brain/brain.py +++ b/src/codeinterpreterapi/brain/brain.py @@ -97,11 +97,16 @@ def prepare_input(self, input: Input): del input['intermediate_steps'] # set agent_results - input['agent_executor_result'] = self.agent_executor_result - input['plan_list'] = self.plan_list - input['supervisor_result'] = self.supervisor_result - input['thought_result'] = self.thought_result - input['crew_result'] = self.crew_result + if isinstance(input, list): + last_input = input[-1] + else: + last_input = input + + last_input['agent_executor_result'] = self.agent_executor_result + last_input['plan_list'] = self.plan_list + last_input['supervisor_result'] = self.supervisor_result + last_input['thought_result'] = self.thought_result + last_input['crew_result'] = self.crew_result return input def run(self, input: Input, runnable_config: Optional[RunnableConfig] = None) -> Output: diff --git a/src/codeinterpreterapi/crew/crew_agent.py b/src/codeinterpreterapi/crew/crew_agent.py index eae7159f..88315c40 100644 --- a/src/codeinterpreterapi/crew/crew_agent.py +++ b/src/codeinterpreterapi/crew/crew_agent.py @@ -1,4 +1,4 @@ -from typing import Dict, List +from typing import Dict, List, Union from crewai import Agent, Crew, Task @@ -69,15 +69,26 @@ def create_task(self, final_goal: str, plan: CodeInterpreterPlan) -> Task: print("WARN: no task found plan.agent_name=", plan.agent_name) return None - def run(self, inputs: Dict, plan_list: CodeInterpreterPlanList): + def run(self, inputs: Union[Dict, List[Dict]], plan_list: CodeInterpreterPlanList): # update task description if plan_list is None: return {} - tasks = self.create_tasks(final_goal=inputs["input"], plan_list=plan_list) - crew_inputs = {"input": inputs.get("input", "")} + + if isinstance(inputs, list): + last_input = inputs[-1] + else: + last_input = inputs + if "input" in inputs: + final_goal = last_input["input"] + elif "content" in inputs: + final_goal = last_input["content"] + else: + final_goal = "ユーザの指示に従って最終的な回答をしてください" + + tasks = self.create_tasks(final_goal=final_goal, plan_list=plan_list) my_crew = Crew(agents=self.agents, tasks=tasks) - print("CodeInterpreterCrew.kickoff() crew_inputs=", crew_inputs) - result = my_crew.kickoff(inputs=crew_inputs) + print("CodeInterpreterCrew.kickoff() crew_inputs=", last_input) + result = my_crew.kickoff(inputs=last_input) return result diff --git a/src/codeinterpreterapi/session.py b/src/codeinterpreterapi/session.py index 85f6947b..0f9e24b0 100644 --- a/src/codeinterpreterapi/session.py +++ b/src/codeinterpreterapi/session.py @@ -19,7 +19,6 @@ from codeinterpreterapi.brain.brain import CodeInterpreterBrain from codeinterpreterapi.brain.params import CodeInterpreterParams -from codeinterpreterapi.callbacks.markdown.callbacks import MarkdownFileCallbackHandler from codeinterpreterapi.chains import aremove_download_link from codeinterpreterapi.chat_history import CodeBoxChatMessageHistory from codeinterpreterapi.config import settings @@ -27,6 +26,8 @@ from codeinterpreterapi.schema import CodeInterpreterResponse, File, SessionStatus, UserRequest from codeinterpreterapi.utils.multi_converter import MultiConverter +BaseMessageContent = Union[str, List[Union[str, Dict]]] + def _handle_deprecated_kwargs(kwargs: dict) -> None: settings.MODEL = kwargs.get("model", settings.MODEL) @@ -105,7 +106,8 @@ def __init__( configurable = {"session_id": init_session_id} # TODO: set session_id runnable_config = RunnableConfig( configurable=configurable, - callbacks=[AgentCallbackHandler(self._output_handler), MarkdownFileCallbackHandler("langchain_log.md")], + callbacks=[], + # callbacks=[AgentCallbackHandler(self._output_handler), MarkdownFileCallbackHandler("langchain_log.md")], ) # ci_params = {} @@ -199,34 +201,61 @@ def _history_backend(self) -> BaseChatMessageHistory: ) ) - def _input_handler(self, request: UserRequest) -> None: + def _input_message_prepare(self, request: UserRequest) -> None: + # set return input_message + if isinstance(request.content, str): + input_message = {"input": request.content, "agent_scratchpad": ""} + else: + input_message = request.content + return input_message + + def _input_handler_common(self, request: UserRequest, add_content_str: str) -> None: """Callback function to handle user input.""" + # TODO: variables as context to the agent + # TODO: current files as context to the agent if not request.files: - return + return self._input_message_prepare(request) if not request.content: - request.content = "I uploaded, just text me back and confirm that you got the file(s)." - assert isinstance(request.content, str), "TODO: implement image support" - request.content += "\n**The user uploaded the following files: **\n" + add_content_str = "I uploaded, just text me back and confirm that you got the file(s).\n" + add_content_str + assert isinstance(request.content, BaseMessageContent) # "TODO: implement image support" + + # set request.content + if isinstance(request.content, str): + request.content += add_content_str + elif isinstance(request.content, list): + last_content = request.content[-1] + if isinstance(last_content, str): + last_content += add_content_str + else: + pass + # TODO impl it. + # last_content["input"] += add_content_str + else: + # Dict + pass + # TODO impl it. + # request.content["input"] += add_content_str + return self._input_message_prepare(request) + + def _input_handler(self, request: UserRequest) -> None: + """Callback function to handle user input.""" + add_content_str = "\n**The user uploaded the following files: **\n" for file in request.files: self.input_files.append(file) - request.content += f"[Attachment: {file.name}]\n" + add_content_str += f"[Attachment: {file.name}]\n" self.ci_params.codebox.upload(file.name, file.content) - request.content += "**File(s) are now available in the cwd. **\n" + add_content_str += "**File(s) are now available in the cwd. **\n" + return self._input_handler_common(request, add_content_str) async def _ainput_handler(self, request: UserRequest) -> None: - # TODO: variables as context to the agent - # TODO: current files as context to the agent - if not request.files: - return - if not request.content: - request.content = "I uploaded, just text me back and confirm that you got the file(s)." - assert isinstance(request.content, str), "TODO: implement image support" - request.content += "\n**The user uploaded the following files: **\n" + """Callback function to handle user input.""" + add_content_str = "\n**The user uploaded the following files: **\n" for file in request.files: self.input_files.append(file) - request.content += f"[Attachment: {file.name}]\n" + add_content_str += f"[Attachment: {file.name}]\n" await self.ci_params.codebox.aupload(file.name, file.content) - request.content += "**File(s) are now available in the cwd. **\n" + add_content_str += "**File(s) are now available in the cwd. **\n" + return self._input_handler_common(request, add_content_str) def _output_handler_pre(self, response: Any) -> str: print("_output_handler_pre response(type)=", type(response)) @@ -304,7 +333,7 @@ async def _aoutput_handler(self, response: str) -> CodeInterpreterResponse: def generate_response_sync( self, - user_msg: str, + user_msg: BaseMessageContent, files: list[File] = [], ) -> CodeInterpreterResponse: print("DEPRECATION WARNING: Use generate_response for sync generation.\n") @@ -315,16 +344,16 @@ def generate_response_sync( def generate_response( self, - user_msg: str, - files: list[File] = [], + user_msg: BaseMessageContent, + files: list[File] = None, ) -> CodeInterpreterResponse: """Generate a Code Interpreter response based on the user's input.""" + if files is None: + files = [] user_request = UserRequest(content=user_msg, files=files) try: - self._input_handler(user_request) + input_message = self._input_handler(user_request) print("generate_response type(user_request.content)=", type(user_request.content)) - agent_scratchpad = "" - input_message = {"input": user_request.content, "agent_scratchpad": agent_scratchpad} # ======= ↓↓↓↓ LLM invoke ↓↓↓↓ #======= response = self.brain.invoke(input=input_message) # ======= ↑↑↑↑ LLM invoke ↑↑↑↑ #======= @@ -348,15 +377,15 @@ def generate_response( async def agenerate_response( self, - user_msg: str, + user_msg: BaseMessageContent, files: list[File] = None, ) -> CodeInterpreterResponse: """Generate a Code Interpreter response based on the user's input.""" if files is None: - files = None + files = [] user_request = UserRequest(content=user_msg, files=files) try: - await self._ainput_handler(user_request) + input_message = await self._ainput_handler(user_request) agent_scratchpad = "" input_message = {"input": user_request.content, "agent_scratchpad": agent_scratchpad} # ======= ↓↓↓↓ LLM invoke ↓↓↓↓ #======= @@ -380,7 +409,7 @@ async def agenerate_response( def generate_response_stream( self, - user_msg: str, + user_msg: BaseMessageContent, files: list[File] = None, ) -> Iterator[str]: """Generate a Code Interpreter response based on the user's input.""" @@ -388,10 +417,7 @@ def generate_response_stream( files = [] user_request = UserRequest(content=user_msg, files=files) try: - self._input_handler(user_request) - agent_scratchpad = "" - input_message = {"input": user_request.content, "agent_scratchpad": agent_scratchpad} - print("generate_response_stream type(user_request.content)=", user_request.content) + input_message = self._input_handler(user_request) # ======= ↓↓↓↓ LLM invoke ↓↓↓↓ #======= response_stream = self.brain.stream(input=input_message) # ======= ↑↑↑↑ LLM invoke ↑↑↑↑ #======= @@ -417,7 +443,7 @@ def generate_response_stream( async def agenerate_response_stream( self, - user_msg: str, + user_msg: BaseMessageContent, files: list[File] = None, ) -> AsyncGenerator[CodeInterpreterResponse, None]: """Generate a Code Interpreter response based on the user's input."""