From 390800d41f420af17266a6a26c5fbc78af9dfc79 Mon Sep 17 00:00:00 2001 From: jinno Date: Tue, 13 Aug 2024 12:43:23 +0900 Subject: [PATCH] fix: change agent tool --- .../agents/config/coding_agent/config.yaml | 2 +- .../agents/config/split_agent/config.yaml | 2 +- .../agents/structured_chat/agent.py | 2 +- src/codeinterpreterapi/crew/custom_agent.py | 14 +++++++++++--- src/codeinterpreterapi/session.py | 11 ++++++----- src/codeinterpreterapi/tools/code_checker.py | 10 +++++----- 6 files changed, 25 insertions(+), 16 deletions(-) diff --git a/src/codeinterpreterapi/agents/config/coding_agent/config.yaml b/src/codeinterpreterapi/agents/config/coding_agent/config.yaml index fcb8f94b..ba477d40 100644 --- a/src/codeinterpreterapi/agents/config/coding_agent/config.yaml +++ b/src/codeinterpreterapi/agents/config/coding_agent/config.yaml @@ -1,6 +1,6 @@ agent_definition: agent_name: "main_function_create_agent" - agent_type: structured_chat + agent_type: tool_calling agent_role: | pythonプログラムのmain関数を作成する作業をお願いします。 別のagentはプログラムの分割やテストの作成を実施することになります。 diff --git a/src/codeinterpreterapi/agents/config/split_agent/config.yaml b/src/codeinterpreterapi/agents/config/split_agent/config.yaml index 1f808fae..cf902fd4 100644 --- a/src/codeinterpreterapi/agents/config/split_agent/config.yaml +++ b/src/codeinterpreterapi/agents/config/split_agent/config.yaml @@ -1,6 +1,6 @@ agent_definition: agent_name: "code_split_agent" - agent_type: structured_chat + agent_type: tool_calling agent_role: | プログラムを分割する作業をお願いします。 別のagentはプログラムの作成やテストの作成を実施することになります。 diff --git a/src/codeinterpreterapi/agents/structured_chat/agent.py b/src/codeinterpreterapi/agents/structured_chat/agent.py index bb12131c..d488d3a4 100644 --- a/src/codeinterpreterapi/agents/structured_chat/agent.py +++ b/src/codeinterpreterapi/agents/structured_chat/agent.py @@ -172,7 +172,7 @@ def create_structured_chat_agent( | llm_with_stop ) agent = assign_runnable_history(agent, runnable_config) - agent = agent | output_parser + # agent = agent | output_parser return agent diff --git a/src/codeinterpreterapi/crew/custom_agent.py b/src/codeinterpreterapi/crew/custom_agent.py index afb1180c..b8fa932d 100644 --- a/src/codeinterpreterapi/crew/custom_agent.py +++ b/src/codeinterpreterapi/crew/custom_agent.py @@ -4,6 +4,8 @@ from langchain_core.tools import BaseTool from pydantic import Field +from codeinterpreterapi.utils.multi_converter import MultiConverter + class CustomAgent(BaseAgent): agent_executor: Any = Field(default=None, description="Verbose mode for the Agent Execution") @@ -28,7 +30,11 @@ def interpolate_inputs(self, inputs: Dict[str, Any]) -> None: print("interpolate_inputs inputs=", inputs) super().interpolate_inputs(inputs) - def execute_task(self, task: Any, context: Optional[str] = None, tools: Optional[List[Any]] = None): + def execute_task(self, task: Any, context: Optional[str] = None, tools: Optional[List[Any]] = None) -> str: + # Notice: ValidationError - 1 validation error for TaskOutput | raw: Input should be a valid string + # crewaiのTaskOutputのrawに入るのでstrで返す必要がある。 + # TODO: 直接dictを返せるようにcrewaiを直す? + # AgentExecutorを使用してタスクを実行 # print("execute_task task=", task) print("execute_task context=", context) @@ -38,9 +44,11 @@ def execute_task(self, task: Any, context: Optional[str] = None, tools: Optional input_dict["question"] = task.prompt_context input_dict["message"] = "タスクを実行してください。\n" + task.expected_output result = self.agent_executor.invoke(input=input_dict) - print("execute_task result=", result) + result_str = MultiConverter.to_str(result) + print("execute_task result=", result_str) + # TODO: return full dict when crewai is updated - return result["output"] + return result_str def create_agent_executor(self, tools=None) -> None: pass diff --git a/src/codeinterpreterapi/session.py b/src/codeinterpreterapi/session.py index ea0ee30f..abc8d2b9 100644 --- a/src/codeinterpreterapi/session.py +++ b/src/codeinterpreterapi/session.py @@ -25,6 +25,7 @@ from codeinterpreterapi.config import settings from codeinterpreterapi.llm.llm import CodeInterpreterLlm from codeinterpreterapi.schema import CodeInterpreterResponse, File, SessionStatus, UserRequest +from codeinterpreterapi.utils.multi_converter import MultiConverter def _handle_deprecated_kwargs(kwargs: dict) -> None: @@ -49,6 +50,7 @@ def on_chain_start(self, serialized: Dict[str, Any], inputs: Dict[str, Any], **k def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> Any: """Run when chain ends running.""" print("AgentCallbackHandler on_chain_end type(outputs)=", type(outputs)) + print("AgentCallbackHandler on_chain_end type(outputs)=", outputs) self.agent_callback_func(outputs) def on_chat_model_start( @@ -228,9 +230,10 @@ async def _ainput_handler(self, request: UserRequest) -> None: def _output_handler_pre(self, response: Any) -> str: print("_output_handler_pre response(type)=", type(response)) - if isinstance(response, str): - output_str = response - elif isinstance(response, dict): + output_str = MultiConverter.to_str(response) + + # TODO: MultiConverterに共通化 + if isinstance(response, dict): output_str = "" code_log_item = {} if "output" in response: @@ -242,8 +245,6 @@ def _output_handler_pre(self, response: Any) -> str: if "log" in response: code_log_item["log"] = str(response["log"]) self.output_code_log = code_log_item - else: - output_str = "response=" + str(response) return output_str def _output_handler_post(self, final_response: str) -> CodeInterpreterResponse: diff --git a/src/codeinterpreterapi/tools/code_checker.py b/src/codeinterpreterapi/tools/code_checker.py index dc42e70c..664f14c6 100644 --- a/src/codeinterpreterapi/tools/code_checker.py +++ b/src/codeinterpreterapi/tools/code_checker.py @@ -5,7 +5,7 @@ from codeinterpreterapi.brain.params import CodeInterpreterParams from codeinterpreterapi.config import settings from codeinterpreterapi.llm.llm import prepare_test_llm -from codeinterpreterapi.schema import CodeInput +from codeinterpreterapi.schema import FileInput from codeinterpreterapi.utils.file_util import FileUtil @@ -27,13 +27,13 @@ def get_tools_code_checker(cls, ci_params: CodeInterpreterParams) -> None: "Start from this latest code first.\n", func=tools_instance._get_latest_code, coroutine=tools_instance._aget_latest_code, - args_schema=CodeInput, # type: ignore + args_schema=FileInput, # type: ignore ), ] return tools - def _get__get_latest_code_common(self, filename=""): + def _get_latest_code_common(self, filename=""): target_dir = "./" target_filename = "main.py" @@ -51,10 +51,10 @@ def _get__get_latest_code_common(self, filename=""): return "" def _get_latest_code(self, filename=""): - return self._get__get_latest_code_common(filename) + return self._get_latest_code_common(filename) async def _aget_latest_code(self, filename=""): - return self._get__get_latest_code_common(filename) + return self._get_latest_code_common(filename) def test():