Skip to content

Commit

Permalink
fix: filter agent_tools for each agent
Browse files Browse the repository at this point in the history
  • Loading branch information
nobu007 committed Sep 17, 2024
1 parent d7d85d1 commit 8293c58
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from codeinterpreterapi.agents.structured_chat.prompts import create_structured_chat_agent_prompt
from codeinterpreterapi.brain.params import CodeInterpreterParams
from codeinterpreterapi.llm.llm import prepare_test_llm
from codeinterpreterapi.tools.tools import CodeInterpreterTools


def load_structured_chat_agent_executor(
Expand All @@ -23,9 +24,10 @@ def load_structured_chat_agent_executor(
if ci_params.verbose_prompt:
print("load_structured_chat_agent_executor prompt.input_variables=", input_variables)
print("load_structured_chat_agent_executor prompt=", prompt.messages)
agent_tools = CodeInterpreterTools.get_agent_tools(agent_tools=agent_def.agent_tools, all_tools=ci_params.tools)
agent = create_structured_chat_agent(
llm=ci_params.llm_tools,
tools=ci_params.tools,
tools=agent_tools,
# output_parser=output_parser,
prompt=prompt,
runnable_config=ci_params.runnable_config,
Expand All @@ -34,7 +36,7 @@ def load_structured_chat_agent_executor(
if agent_def:
agent_def.agent = agent

agent_executor = AgentExecutor.from_agent_and_tools(agent=agent, tools=ci_params.tools, verbose=ci_params.verbose)
agent_executor = AgentExecutor.from_agent_and_tools(agent=agent, tools=agent_tools, verbose=ci_params.verbose)
if agent_def:
agent_def.agent_executor = agent_executor
return agent_executor
Expand Down
6 changes: 4 additions & 2 deletions src/codeinterpreterapi/agents/tool_calling/agent_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from codeinterpreterapi.agents.tool_calling.prompts import create_tool_calling_agent_prompt
from codeinterpreterapi.brain.params import CodeInterpreterParams
from codeinterpreterapi.llm.llm import prepare_test_llm
from codeinterpreterapi.tools.tools import CodeInterpreterTools


def load_tool_calling_agent_executor(
Expand All @@ -23,17 +24,18 @@ def load_tool_calling_agent_executor(
input_variables = prompt.input_variables
print("load_tool_calling_agent_executor prompt.input_variables=", input_variables)
print("load_tool_calling_agent_executor prompt=", prompt.messages)
agent_tools = CodeInterpreterTools.get_agent_tools(agent_tools=agent_def.agent_tools, all_tools=ci_params.tools)
agent = create_tool_calling_agent(
llm=ci_params.llm_tools,
tools=ci_params.tools,
tools=agent_tools,
# output_parser=output_parser,
prompt=prompt,
runnable_config=ci_params.runnable_config,
)
if agent_def:
agent_def.agent = agent

agent_executor = AgentExecutor.from_agent_and_tools(agent=agent, tools=ci_params.tools, verbose=ci_params.verbose)
agent_executor = AgentExecutor.from_agent_and_tools(agent=agent, tools=agent_tools, verbose=ci_params.verbose)
if agent_def:
agent_def.agent_executor = agent_executor

Expand Down
2 changes: 1 addition & 1 deletion src/codeinterpreterapi/crew/crew_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def test():
ci_params = CodeInterpreterParams.get_test_params(llm=llm, llm_tools=llm_tools, runnable_config=runnable_config)
_ = CodeInterpreterAgent.choose_agent_executors(ci_params=ci_params)
inputs = {"input": TestPrompt.svg_input_str}
plan = CodeInterpreterPlan(agent_name="main_function_create_agent", task_description="", expected_output="")
plan = CodeInterpreterPlan(agent_name="code_write_agent", task_description="", expected_output="")
plan_list = CodeInterpreterPlanList(reliability=80, agent_task_list=[plan, plan])
result = CodeInterpreterCrew(ci_params).run(inputs, plan_list)
print(result)
Expand Down
9 changes: 9 additions & 0 deletions src/codeinterpreterapi/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from codeinterpreterapi.tools.bash import BashTools
from codeinterpreterapi.tools.code_checker import CodeChecker
from codeinterpreterapi.tools.python import PythonTools
from typing import List


class CodeInterpreterTools:
Expand Down Expand Up @@ -40,3 +41,11 @@ def add_tools_shell(self) -> None:
def add_tools_web_search(self) -> None:
tools = [TavilySearchResults(max_results=1)]
self._additional_tools += tools

@staticmethod
def get_agent_tools(agent_tools: str, all_tools: List[BaseTool]) -> None:
selected_tools = []
for tool in all_tools:
if tool.name in agent_tools:
selected_tools.append(tool)
return selected_tools

0 comments on commit 8293c58

Please sign in to comment.