From 004494611029805dce1f433ceffc37fd710af444 Mon Sep 17 00:00:00 2001 From: jinno Date: Sat, 24 Aug 2024 15:12:19 +0900 Subject: [PATCH] fix: add thought for GuiAgentInterpreterChatResponse --- src/codeinterpreterapi/crew/custom_agent.py | 2 +- src/codeinterpreterapi/schema.py | 3 ++- src/codeinterpreterapi/session.py | 2 +- .../utils/multi_converter.py | 19 +++++++++++++++++-- 4 files changed, 21 insertions(+), 5 deletions(-) diff --git a/src/codeinterpreterapi/crew/custom_agent.py b/src/codeinterpreterapi/crew/custom_agent.py index b8fa932d..08fb7d5c 100644 --- a/src/codeinterpreterapi/crew/custom_agent.py +++ b/src/codeinterpreterapi/crew/custom_agent.py @@ -27,7 +27,6 @@ def __init__(self, agent_executor: Any, **data): def interpolate_inputs(self, inputs: Dict[str, Any]) -> None: """Interpolate inputs into the task description and expected output.""" - print("interpolate_inputs inputs=", inputs) super().interpolate_inputs(inputs) def execute_task(self, task: Any, context: Optional[str] = None, tools: Optional[List[Any]] = None) -> str: @@ -45,6 +44,7 @@ def execute_task(self, task: Any, context: Optional[str] = None, tools: Optional input_dict["message"] = "タスクを実行してください。\n" + task.expected_output result = self.agent_executor.invoke(input=input_dict) result_str = MultiConverter.to_str(result) + print("execute_task result(type)=", type(result_str)) print("execute_task result=", result_str) # TODO: return full dict when crewai is updated diff --git a/src/codeinterpreterapi/schema.py b/src/codeinterpreterapi/schema.py index 4dd3638b..7a754931 100644 --- a/src/codeinterpreterapi/schema.py +++ b/src/codeinterpreterapi/schema.py @@ -116,7 +116,8 @@ class CodeInterpreterResponse(AIMessage): files: Optional[list[File]] = [] code_log: Optional[dict[str, str]] = [] - agent_name: str = "" + agent_name: Optional[str] = "" + thought: Optional[str] = "" # 中間的な思考 def show(self) -> None: print("AI: ", self.content) diff --git a/src/codeinterpreterapi/session.py b/src/codeinterpreterapi/session.py index 0f9e24b0..23dc6d94 100644 --- a/src/codeinterpreterapi/session.py +++ b/src/codeinterpreterapi/session.py @@ -411,7 +411,7 @@ def generate_response_stream( self, user_msg: BaseMessageContent, files: list[File] = None, - ) -> Iterator[str]: + ) -> Iterator[CodeInterpreterResponse]: """Generate a Code Interpreter response based on the user's input.""" if files is None: files = [] diff --git a/src/codeinterpreterapi/utils/multi_converter.py b/src/codeinterpreterapi/utils/multi_converter.py index f2f0650b..a67e9b09 100644 --- a/src/codeinterpreterapi/utils/multi_converter.py +++ b/src/codeinterpreterapi/utils/multi_converter.py @@ -1,5 +1,6 @@ from typing import Any, Dict, List +from crewai.crews.crew_output import CrewOutput, TaskOutput from langchain_core.messages import AIMessageChunk @@ -14,13 +15,16 @@ def to_str(input_obj: Any) -> str: if len(input_obj) > 0: input_obj = MultiConverter._process_dict(input_obj[-1]) else: - input_obj = "" + return "no output" elif isinstance(input_obj, Dict): input_obj = MultiConverter._process_dict(input_obj) + elif isinstance(input_obj, CrewOutput): + input_obj = MultiConverter._process_crew_output(input_obj) else: + print("MultiConverter to_str type(input_obj)=", type(input_obj)) return str(input_obj) - # 再帰 + # 確実にstr以外は念のため再帰 return MultiConverter.to_str(input_obj) @staticmethod @@ -43,3 +47,14 @@ def _process_dict(input_dict: Dict[str, Any]) -> str: keys = ["tool", "tool_input_obj", "log"] code_log_item = {key: str(input_dict[key]) for key in keys if key in input_dict} return str(code_log_item) if code_log_item else str(input_dict) + + @staticmethod + def _process_crew_output(input_crew_output: CrewOutput) -> str: + # TODO: return json or + last_task_output: TaskOutput = input_crew_output.tasks_output[-1] + if last_task_output.json_dict: + return str(last_task_output.json_dict) + elif last_task_output.pydantic: + return str(last_task_output.pydantic) + else: + return last_task_output.raw