From 2ba99a3b2d6a8d61c9c4d03f38b44bda76b00506 Mon Sep 17 00:00:00 2001 From: jinno Date: Sat, 24 Aug 2024 20:18:12 +0900 Subject: [PATCH] fix: add CodeInterpreterIntermediateResult for schema --- src/codeinterpreterapi/crew/crew_agent.py | 10 +++++-- src/codeinterpreterapi/schema.py | 27 +++++++++++++---- src/codeinterpreterapi/session.py | 24 ++++++++++----- .../supervisors/supervisors.py | 30 +++++++++++-------- 4 files changed, 64 insertions(+), 27 deletions(-) diff --git a/src/codeinterpreterapi/crew/crew_agent.py b/src/codeinterpreterapi/crew/crew_agent.py index 88315c40..f2451775 100644 --- a/src/codeinterpreterapi/crew/crew_agent.py +++ b/src/codeinterpreterapi/crew/crew_agent.py @@ -1,6 +1,7 @@ from typing import Dict, List, Union from crewai import Agent, Crew, Task +from crewai.crews.crew_output import CrewOutput from codeinterpreterapi.agents.agents import CodeInterpreterAgent from codeinterpreterapi.brain.params import CodeInterpreterParams @@ -9,7 +10,7 @@ ) from codeinterpreterapi.graphs.agent_wrapper_tool import AgentWrapperTool from codeinterpreterapi.llm.llm import prepare_test_llm -from codeinterpreterapi.schema import CodeInterpreterPlan, CodeInterpreterPlanList +from codeinterpreterapi.schema import CodeInterpreterIntermediateResult, CodeInterpreterPlan, CodeInterpreterPlanList from codeinterpreterapi.test_prompts.test_prompt import TestPrompt @@ -69,7 +70,9 @@ def create_task(self, final_goal: str, plan: CodeInterpreterPlan) -> Task: print("WARN: no task found plan.agent_name=", plan.agent_name) return None - def run(self, inputs: Union[Dict, List[Dict]], plan_list: CodeInterpreterPlanList): + def run( + self, inputs: Union[Dict, List[Dict]], plan_list: CodeInterpreterPlanList + ) -> CodeInterpreterIntermediateResult: # update task description if plan_list is None: return {} @@ -88,7 +91,8 @@ def run(self, inputs: Union[Dict, List[Dict]], plan_list: CodeInterpreterPlanLis tasks = self.create_tasks(final_goal=final_goal, plan_list=plan_list) my_crew = Crew(agents=self.agents, tasks=tasks) print("CodeInterpreterCrew.kickoff() crew_inputs=", last_input) - result = my_crew.kickoff(inputs=last_input) + crew_result: CrewOutput = my_crew.kickoff(inputs=last_input) + result = CodeInterpreterIntermediateResult(context=crew_result.raw) return result diff --git a/src/codeinterpreterapi/schema.py b/src/codeinterpreterapi/schema.py index 7a754931..f3f4a756 100644 --- a/src/codeinterpreterapi/schema.py +++ b/src/codeinterpreterapi/schema.py @@ -1,5 +1,5 @@ import asyncio -from typing import Any, List, Optional +from typing import Any, Dict, List, Optional from codeboxapi.schema import CodeBoxStatus from langchain_core.messages import AIMessage, HumanMessage @@ -106,16 +106,33 @@ def __repr__(self) -> str: return f"UserRequest(content={self.content}, files={self.files})" +class CodeInterpreterIntermediateResult(BaseModel): + thoughts: List[str] = Field( + default_factory=list, + description="エージェントの思考プロセスを表す文字列のリスト(最新の思考および根拠を理解するために必要な情報のみが入っている)", + ) + context: str = Field(description="llmやagentからの回答本文") + code: str = Field(default="", description="プログラムのソースコード") + log: str = Field(default="", description="コードの実行結果やテスト結果などのlog") + language: str = Field(default="", description="llmやagentからの回答本文") + confidence: float = Field(default=0.95, description="現在の回答の信頼度[0.0~1.0], 1.0が最も信頼できる") + target_confidence: float = Field(default=0.95, description="目標とする信頼度") + metadata: Dict[str, Any] = Field(default_factory=dict, description="追加のメタデータを格納する辞書") + iteration_count: int = Field(default=0, description="現在の反復回数") + max_iterations: int = Field(default=10, description="最大反復回数") + + class CodeInterpreterResponse(AIMessage): """ Response from the code interpreter agent. - - files: list of files to be sent to the user (File ) - code_log: list[tuple[str, str]] = [] """ files: Optional[list[File]] = [] - code_log: Optional[dict[str, str]] = [] + code: str = "" # final code + log: str = "" # final log + language: str = "" # ex: python, java + start: bool = False + end: bool = False agent_name: Optional[str] = "" thought: Optional[str] = "" # 中間的な思考 diff --git a/src/codeinterpreterapi/session.py b/src/codeinterpreterapi/session.py index 23dc6d94..c4281e6a 100644 --- a/src/codeinterpreterapi/session.py +++ b/src/codeinterpreterapi/session.py @@ -131,7 +131,7 @@ def __init__( self.input_files: list[File] = [] self.output_files: list[File] = [] - self.output_code_log: list[tuple[str, str]] = [] + self.output_code_log_list: list[tuple[str, str]] = [] @classmethod def from_id(cls, session_id: UUID, **kwargs: Any) -> "CodeInterpreterSession": @@ -273,7 +273,7 @@ def _output_handler_pre(self, response: Any) -> str: code_log_item["tool_input"] = str(response["tool_input"]) if "log" in response: code_log_item["log"] = str(response["log"]) - self.output_code_log = code_log_item + self.output_code_log_list = code_log_item return output_str def _output_handler_post(self, final_response: str) -> CodeInterpreterResponse: @@ -292,13 +292,23 @@ def _output_handler_post(self, final_response: str) -> CodeInterpreterResponse: # print("Error while removing download links:", e) output_files = self.output_files - code_log = self.output_code_log + code_log = self.output_code_log_list + final_code = "" + final_log = "" + if len(final_code) > 0: + final_code = code_log[-1][0] + final_log = code_log[-1][0] self.output_files = [] - self.output_code_log = [] + self.output_code_log_list = [] print("_output_handler self.brain.current_agent=", self.brain.current_agent) response = CodeInterpreterResponse( - content=final_response, files=output_files, code_log=code_log, agent_name=self.brain.current_agent + content=final_response, + files=output_files, + code_log=code_log, + agent_name=self.brain.current_agent, + code=final_code, + log=final_log, ) return response @@ -324,9 +334,9 @@ async def _aoutput_handler(self, response: str) -> CodeInterpreterResponse: print("Error while removing download links:", e) output_files = self.output_files - code_log = self.output_code_log + code_log = self.output_code_log_list self.output_files = [] - self.output_code_log = [] + self.output_code_log_list = [] response = CodeInterpreterResponse(content=final_response, files=output_files, code_log=code_log) return response diff --git a/src/codeinterpreterapi/supervisors/supervisors.py b/src/codeinterpreterapi/supervisors/supervisors.py index 38252124..99c1b910 100644 --- a/src/codeinterpreterapi/supervisors/supervisors.py +++ b/src/codeinterpreterapi/supervisors/supervisors.py @@ -6,16 +6,17 @@ from langchain.agents import AgentExecutor from langchain_core.pydantic_v1 import BaseModel, Field from langchain_core.runnables import Runnable -from langchain_core.runnables.utils import Input, Output +from langchain_core.runnables.utils import Input from codeinterpreterapi.agents.agents import CodeInterpreterAgent from codeinterpreterapi.brain.params import CodeInterpreterParams from codeinterpreterapi.crew.crew_agent import CodeInterpreterCrew from codeinterpreterapi.llm.llm import prepare_test_llm from codeinterpreterapi.planners.planners import CodeInterpreterPlanner -from codeinterpreterapi.schema import CodeInterpreterPlanList +from codeinterpreterapi.schema import CodeInterpreterIntermediateResult, CodeInterpreterPlanList from codeinterpreterapi.supervisors.prompts import create_supervisor_agent_prompt from codeinterpreterapi.test_prompts.test_prompt import TestPrompt +from codeinterpreterapi.utils.multi_converter import MultiConverter class CodeInterpreterSupervisor: @@ -81,22 +82,27 @@ def get_executor(self) -> AgentExecutor: # TODO: impl return self.supervisor_chain - def invoke(self, input: Input) -> Output: - result = self.planner.invoke(input, config=self.ci_params.runnable_config) - print("supervisor.invoke type(result)=", type(result)) - if isinstance(result, CodeInterpreterPlanList): - plan_list: CodeInterpreterPlanList = result + def invoke(self, input: Input) -> CodeInterpreterIntermediateResult: + planner_result = self.planner.invoke(input, config=self.ci_params.runnable_config) + print("supervisor.invoke type(planner_result)=", type(planner_result)) + if isinstance(planner_result, CodeInterpreterPlanList): + plan_list: CodeInterpreterPlanList = planner_result if len(plan_list.agent_task_list) > 0: print("supervisor.invoke use crew_agent plan_list=", plan_list) - result = self.ci_params.crew_agent.run(input, result) + result: CodeInterpreterIntermediateResult = self.ci_params.crew_agent.run(input, plan_list) else: - print("supervisor.invoke no_agent plan_list=", plan_list) - result = self.supervisor_chain_no_agent.invoke(input) + print("supervisor.invoke empty plan_list") + result_dict = self.supervisor_chain_no_agent.invoke(input) + result_str = MultiConverter.to_str(result_dict) + result = CodeInterpreterIntermediateResult(content=result_str) else: print("supervisor.invoke no_agent no plan_list") - result = self.supervisor_chain_no_agent.invoke(input) + result_dict = self.supervisor_chain_no_agent.invoke(input) + result_str = MultiConverter.to_str(result_dict) + result = CodeInterpreterIntermediateResult(content=result_str) return result + # NOT USED def execute_plan(self, plan_list: CodeInterpreterPlanList) -> Dict[str, Any]: print("supervisor.execute_plan type(plan_list)=", type(plan_list)) # AgentExecutorの初期化 @@ -115,7 +121,7 @@ def execute_plan(self, plan_list: CodeInterpreterPlanList) -> Dict[str, Any]: try: # タスクを実行 - result = agent.run( + result: CodeInterpreterIntermediateResult = agent.run( f"Task: {plan.task_description}\n" f"Expected output: {plan.expected_output}\n" "Please complete this task."