diff --git a/src/codeinterpreterapi/agents/structured_chat/agent_executor.py b/src/codeinterpreterapi/agents/structured_chat/agent_executor.py index 796537c9..4d5201c0 100644 --- a/src/codeinterpreterapi/agents/structured_chat/agent_executor.py +++ b/src/codeinterpreterapi/agents/structured_chat/agent_executor.py @@ -8,6 +8,7 @@ from codeinterpreterapi.agents.structured_chat.prompts import create_structured_chat_agent_prompt from codeinterpreterapi.brain.params import CodeInterpreterParams from codeinterpreterapi.llm.llm import prepare_test_llm +from codeinterpreterapi.tools.tools import CodeInterpreterTools def load_structured_chat_agent_executor( @@ -23,9 +24,10 @@ def load_structured_chat_agent_executor( if ci_params.verbose_prompt: print("load_structured_chat_agent_executor prompt.input_variables=", input_variables) print("load_structured_chat_agent_executor prompt=", prompt.messages) + agent_tools = CodeInterpreterTools.get_agent_tools(agent_tools=agent_def.agent_tools, all_tools=ci_params.tools) agent = create_structured_chat_agent( llm=ci_params.llm_tools, - tools=ci_params.tools, + tools=agent_tools, # output_parser=output_parser, prompt=prompt, runnable_config=ci_params.runnable_config, @@ -34,7 +36,7 @@ def load_structured_chat_agent_executor( if agent_def: agent_def.agent = agent - agent_executor = AgentExecutor.from_agent_and_tools(agent=agent, tools=ci_params.tools, verbose=ci_params.verbose) + agent_executor = AgentExecutor.from_agent_and_tools(agent=agent, tools=agent_tools, verbose=ci_params.verbose) if agent_def: agent_def.agent_executor = agent_executor return agent_executor diff --git a/src/codeinterpreterapi/agents/tool_calling/agent_executor.py b/src/codeinterpreterapi/agents/tool_calling/agent_executor.py index 1f0c95fc..d069a97b 100644 --- a/src/codeinterpreterapi/agents/tool_calling/agent_executor.py +++ b/src/codeinterpreterapi/agents/tool_calling/agent_executor.py @@ -8,6 +8,7 @@ from codeinterpreterapi.agents.tool_calling.prompts import create_tool_calling_agent_prompt from codeinterpreterapi.brain.params import CodeInterpreterParams from codeinterpreterapi.llm.llm import prepare_test_llm +from codeinterpreterapi.tools.tools import CodeInterpreterTools def load_tool_calling_agent_executor( @@ -23,9 +24,10 @@ def load_tool_calling_agent_executor( input_variables = prompt.input_variables print("load_tool_calling_agent_executor prompt.input_variables=", input_variables) print("load_tool_calling_agent_executor prompt=", prompt.messages) + agent_tools = CodeInterpreterTools.get_agent_tools(agent_tools=agent_def.agent_tools, all_tools=ci_params.tools) agent = create_tool_calling_agent( llm=ci_params.llm_tools, - tools=ci_params.tools, + tools=agent_tools, # output_parser=output_parser, prompt=prompt, runnable_config=ci_params.runnable_config, @@ -33,7 +35,7 @@ def load_tool_calling_agent_executor( if agent_def: agent_def.agent = agent - agent_executor = AgentExecutor.from_agent_and_tools(agent=agent, tools=ci_params.tools, verbose=ci_params.verbose) + agent_executor = AgentExecutor.from_agent_and_tools(agent=agent, tools=agent_tools, verbose=ci_params.verbose) if agent_def: agent_def.agent_executor = agent_executor diff --git a/src/codeinterpreterapi/crew/crew_agent.py b/src/codeinterpreterapi/crew/crew_agent.py index 9460d6eb..264a212e 100644 --- a/src/codeinterpreterapi/crew/crew_agent.py +++ b/src/codeinterpreterapi/crew/crew_agent.py @@ -132,7 +132,7 @@ def test(): ci_params = CodeInterpreterParams.get_test_params(llm=llm, llm_tools=llm_tools, runnable_config=runnable_config) _ = CodeInterpreterAgent.choose_agent_executors(ci_params=ci_params) inputs = {"input": TestPrompt.svg_input_str} - plan = CodeInterpreterPlan(agent_name="main_function_create_agent", task_description="", expected_output="") + plan = CodeInterpreterPlan(agent_name="code_write_agent", task_description="", expected_output="") plan_list = CodeInterpreterPlanList(reliability=80, agent_task_list=[plan, plan]) result = CodeInterpreterCrew(ci_params).run(inputs, plan_list) print(result) diff --git a/src/codeinterpreterapi/tools/tools.py b/src/codeinterpreterapi/tools/tools.py index badde107..759d0544 100644 --- a/src/codeinterpreterapi/tools/tools.py +++ b/src/codeinterpreterapi/tools/tools.py @@ -5,6 +5,7 @@ from codeinterpreterapi.tools.bash import BashTools from codeinterpreterapi.tools.code_checker import CodeChecker from codeinterpreterapi.tools.python import PythonTools +from typing import List class CodeInterpreterTools: @@ -40,3 +41,11 @@ def add_tools_shell(self) -> None: def add_tools_web_search(self) -> None: tools = [TavilySearchResults(max_results=1)] self._additional_tools += tools + + @staticmethod + def get_agent_tools(agent_tools: str, all_tools: List[BaseTool]) -> None: + selected_tools = [] + for tool in all_tools: + if tool.name in agent_tools: + selected_tools.append(tool) + return selected_tools