Skip to content

Commit

Permalink
Output handler max iterations (#61)
Browse files Browse the repository at this point in the history
* add max_iterations parameter for output_handler

* add tests for output handler max iterations

* add max_iterations params for CoderOutputHandler in Advanced output handling.ipynb

* add MockTool for agent tests

* replace DuckDuckGoSearchRun with MockTool

* update OutputHandlerMaxIterationsExceeded

* minor fixes

* fix tests

---------

Co-authored-by: User <[email protected]>
Co-authored-by: whimo <[email protected]>
  • Loading branch information
3 people authored Jul 7, 2024
1 parent 40cbf99 commit a4c2fdd
Show file tree
Hide file tree
Showing 43 changed files with 173 additions and 60 deletions.
2 changes: 1 addition & 1 deletion examples/Advanced output handling.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@
"coder = ReActToolCallingAgent(\n",
" name=\"coder\",\n",
" tools=[PythonREPLTool()],\n",
" output_handler=CoderOutputHandler(),\n",
" output_handler=CoderOutputHandler(max_iterations=3),\n",
" verbose=True,\n",
")\n",
"\n",
Expand Down
12 changes: 11 additions & 1 deletion motleycrew/agents/output_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from motleycrew.agents.abstract_parent import MotleyAgentAbstractParent
from motleycrew.common.exceptions import InvalidOutput
from motleycrew.common import Defaults
from motleycrew.tools import MotleyTool


Expand All @@ -22,7 +23,16 @@ class MotleyOutputHandler(MotleyTool, ABC):
_exceptions_to_handle: tuple[Exception] = (InvalidOutput,)
"""Exceptions that should be returned to the agent when raised in the `handle_output` method."""

def __init__(self):
def __init__(self, max_iterations: int = Defaults.DEFAULT_OUTPUT_HANDLER_MAX_ITERATIONS):
    """Set up the output handler and register it as a tool.

    Args:
        max_iterations (int): Upper bound on the number of output handler iterations.
            While the bound is not exceeded, a handled exception raised in
            `handle_output` is returned to the agent; once it is exceeded,
            OutputHandlerMaxIterationsExceeded is raised instead.
    """
    # The limit is only stored here; it is read back from the instance
    # by the code that wraps the handler for the agent.
    self.max_iterations = max_iterations

    # Build the underlying langchain tool, then hand it to the parent tool class.
    wrapped_tool = self._create_langchain_tool()
    super().__init__(wrapped_tool)

Expand Down
19 changes: 17 additions & 2 deletions motleycrew/agents/parent.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@

from motleycrew.agents.abstract_parent import MotleyAgentAbstractParent
from motleycrew.common import MotleyAgentFactory, MotleySupportedTool
from motleycrew.common import logger
from motleycrew.common import logger, Defaults
from motleycrew.common.exceptions import (
AgentNotMaterialized,
CannotModifyMaterializedAgent,
InvalidOutput,
OutputHandlerMaxIterationsExceeded,
)
from motleycrew.tools import MotleyTool

Expand Down Expand Up @@ -131,18 +132,32 @@ def _prepare_output_handler(self) -> Optional[MotleyTool]:
if isinstance(self.output_handler, MotleyOutputHandler):
exceptions_to_handle = self.output_handler.exceptions_to_handle
description = self.output_handler.description
max_iterations = self.output_handler.max_iterations

else:
exceptions_to_handle = (InvalidOutput,)
description = self.output_handler.description or f"Output handler"
assert isinstance(description, str)
description += "\n ONLY RETURN THE FINAL RESULT USING THIS TOOL!"
max_iterations = Defaults.DEFAULT_OUTPUT_HANDLER_MAX_ITERATIONS

iteration = 0

def handle_agent_output(*args, **kwargs):
assert self.output_handler
nonlocal iteration

try:
iteration += 1
output = self.output_handler._run(*args, **kwargs)
except exceptions_to_handle as exc:
return f"{exc.__class__.__name__}: {str(exc)}"
if iteration <= max_iterations:
return f"{exc.__class__.__name__}: {str(exc)}"
raise OutputHandlerMaxIterationsExceeded(
last_call_args=args,
last_call_kwargs=kwargs,
last_exception=exc,
)

raise DirectOutput(output)

Expand Down
6 changes: 5 additions & 1 deletion motleycrew/common/defaults.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
""" Module description """

from motleycrew.common import LLMFamily
from motleycrew.common import GraphStoreType


class Defaults:
""" Description
"""Description
Attributes:
DEFAULT_LLM_FAMILY (str):
Expand All @@ -15,8 +16,10 @@ class Defaults:
MODULE_INSTALL_COMMANDS (dict):
DEFAULT_NUM_THREADS (int):
DEFAULT_EVENT_LOOP_SLEEP (int):
DEFAULT_OUTPUT_HANDLER_MAX_ITERATIONS (int):
"""

DEFAULT_LLM_FAMILY = LLMFamily.OPENAI
DEFAULT_LLM_NAME = "gpt-4o"
DEFAULT_LLM_TEMPERATURE = 0.0
Expand All @@ -35,3 +38,4 @@ class Defaults:

DEFAULT_NUM_THREADS = 4
DEFAULT_EVENT_LOOP_SLEEP = 1
DEFAULT_OUTPUT_HANDLER_MAX_ITERATIONS = 5
19 changes: 18 additions & 1 deletion motleycrew/common/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
""" Module description"""

from typing import Any, Optional
from typing import Any, Dict, Optional

from motleycrew.common import Defaults

Expand Down Expand Up @@ -142,3 +142,20 @@ class InvalidOutput(Exception):
"""Raised in output handlers when an agent's output is not accepted"""

pass


class OutputHandlerMaxIterationsExceeded(BaseException):
    """Raised when the output handler's iteration limit is exceeded.

    Attributes:
        last_call_args (tuple): Positional arguments of the last output handler call.
        last_call_kwargs (dict): Keyword arguments of the last output handler call.
        last_exception (Exception): Exception raised by the last output handler call.
    """

    # NOTE(review): this derives from BaseException rather than Exception,
    # presumably so that generic `except Exception` handlers do not swallow it;
    # confirm before changing the base class.

    def __init__(
        self,
        last_call_args: tuple,
        last_call_kwargs: Dict[str, Any],
        last_exception: Exception,
    ):
        # Forward the values to BaseException so they are recorded in `self.args`.
        # When the exception is constructed with keyword arguments, `args` would
        # otherwise stay empty, and copy/pickle (which rebuild the exception as
        # `cls(*self.args)`) would fail with a TypeError.
        super().__init__(last_call_args, last_call_kwargs, last_exception)
        self.last_call_args = last_call_args
        self.last_call_kwargs = last_call_kwargs
        self.last_exception = last_exception

    def __str__(self):
        # Fixed message: the details are available through the attributes.
        return "Maximum number of output handler iterations exceeded"
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -1 +1 @@
"def bubble_sort(arr):\n n = len(arr)\n for i in range(n):\n swapped = False\n for j in range(0, n-i-1):\n if arr[j] > arr[j+1]:\n arr[j], arr[j+1] = arr[j+1], arr[j]\n swapped = True\n if not swapped:\n break\n return arr\n\n# Test the bubble sort function\nsample_array = [64, 34, 25, 12, 22, 11, 90]\nsorted_array = bubble_sort(sample_array)\nprint(sorted_array)\n\nThe `bubble_sort` function sorts an array using the bubble sort algorithm. It works by repeatedly stepping through the list, comparing adjacent elements and swapping them if they are in the wrong order. This process is repeated until the list is sorted. The outer loop runs `n` times, where `n` is the length of the array, and the inner loop runs `n-i-1` times to avoid re-checking the already sorted elements. An optimization is added by using a `swapped` flag to detect if any swaps were made during an iteration. If no swaps were made, the array is already sorted, and the algorithm can terminate early. The test case demonstrates the function by sorting a sample array."
"def bubble_sort(arr):\n n = len(arr)\n for i in range(n):\n swapped = False\n for j in range(0, n-i-1):\n if arr[j] > arr[j+1]:\n arr[j], arr[j+1] = arr[j+1], arr[j]\n swapped = True\n if not swapped:\n break\n return arr\n\n# Test the bubble sort function\nsample_array = [64, 34, 25, 12, 22, 11, 90]\nsorted_array = bubble_sort(sample_array)\nprint(sorted_array)\n\nThe `bubble_sort` function sorts an array using the bubble sort algorithm. It works by repeatedly stepping through the list, comparing adjacent elements and swapping them if they are in the wrong order. This process is repeated until the list is sorted. The outer loop runs `n` times, where `n` is the length of the array. The inner loop runs `n-i-1` times to avoid re-checking the already sorted elements. An optimization is added by using a `swapped` flag to detect if any swaps were made during an iteration. If no swaps were made, the list is already sorted, and the algorithm can terminate early. The test case demonstrates the function by sorting a sample array."
Original file line number Diff line number Diff line change
@@ -1 +1 @@
"\\[\n\\begin{aligned}\nx &= \\frac{367}{71} \\\\\ny &= -\\frac{25}{49} \\\\\nx - y &= 2\n\\end{aligned}\n\\]"
"Agent stopped due to iteration limit or time limit."
21 changes: 21 additions & 0 deletions tests/test_agents/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from typing import Type
from langchain_core.tools import BaseTool
from langchain_core.pydantic_v1 import BaseModel, Field


class MockToolInput(BaseModel):
    """Input for the MockTool tool."""

    # Single required string argument; MockTool._run returns it unchanged.
    tool_input: str = Field(..., description="tool_input")


class MockTool(BaseTool):
    """Mock tool for run agent tests"""

    # Tool metadata.
    # NOTE(review): the name contains a space; confirm that every consumer
    # of tool names in the tests accepts that.
    name: str = "mock tool"
    description: str = "Mock tool for tests"

    # Schema describing the single `tool_input` argument.
    args_schema: Type[BaseModel] = MockToolInput

    def _run(self, tool_input: str, *args, **kwargs) -> str:
        # Echo the input back; extra positional/keyword arguments are ignored.
        return tool_input
10 changes: 5 additions & 5 deletions tests/test_agents/test_agents.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import os
import pytest

from langchain_community.tools import DuckDuckGoSearchRun
from langchain_core.prompts.chat import ChatPromptTemplate
from motleycrew.agents.crewai.crewai_agent import CrewAIMotleyAgent
from motleycrew.agents.langchain.tool_calling_react import ReActToolCallingAgent
from motleycrew.agents.llama_index.llama_index_react import ReActLlamaIndexMotleyAgent
from motleycrew.common.exceptions import AgentNotMaterialized, CannotModifyMaterializedAgent
from motleycrew.tools.python_repl import create_repl_tool
from motleycrew.tools.tool import MotleyTool
from tests.test_agents import MockTool

os.environ["OPENAI_API_KEY"] = "YOUR OPENAI API KEY"

Expand All @@ -28,7 +28,7 @@ def crewai_agent(self):
backstory="",
verbose=True,
delegation=False,
tools=[DuckDuckGoSearchRun()],
tools=[MockTool()],
)
return agent

Expand All @@ -38,7 +38,7 @@ def langchain_agent(self):
name="AI writer agent",
prompt_prefix="Generate AI-generated content",
description="AI-generated content",
tools=[DuckDuckGoSearchRun()],
tools=[MockTool()],
verbose=True,
)
return agent
Expand All @@ -48,7 +48,7 @@ def llama_index_agent(self):
agent = ReActLlamaIndexMotleyAgent(
prompt_prefix="Uncover cutting-edge developments in AI and data science",
description="AI researcher",
tools=[DuckDuckGoSearchRun()],
tools=[MockTool()],
verbose=True,
)
return agent
Expand All @@ -65,7 +65,7 @@ def agent(self, request, crewai_agent, langchain_agent, llama_index_agent):
@pytest.mark.parametrize("agent", test_agents_names, indirect=True)
def test_add_tools(self, agent):
assert len(agent.tools) == 1
tools = [DuckDuckGoSearchRun()]
tools = [MockTool()]
agent.add_tools(tools)
assert len(agent.tools) == 1

Expand Down
42 changes: 29 additions & 13 deletions tests/test_agents/test_langchain_output_handler.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import pytest
from langchain_community.tools import DuckDuckGoSearchRun
from langchain_core.agents import AgentFinish, AgentAction

from motleycrew.agents import MotleyOutputHandler
from motleycrew.agents.langchain.tool_calling_react import ReActToolCallingAgent
from motleycrew.agents.parent import DirectOutput
from motleycrew.common.exceptions import InvalidOutput
from motleycrew.common.exceptions import InvalidOutput, OutputHandlerMaxIterationsExceeded
from tests.test_agents import MockTool

invalid_output = "Add more information about AI applications in medicine."

Expand Down Expand Up @@ -38,10 +38,10 @@ def fake_agent_take_next_step(
@pytest.fixture
def agent():
agent = ReActToolCallingAgent(
tools=[DuckDuckGoSearchRun()],
tools=[MockTool()],
verbose=True,
chat_history=True,
output_handler=ReportOutputHandler(),
output_handler=ReportOutputHandler(max_iterations=5),
)
agent.materialize()
object.__setattr__(agent._agent, "plan", fake_agent_plan)
Expand All @@ -56,6 +56,19 @@ def agent():
return agent


@pytest.fixture
def run_kwargs(agent):
    """Build the keyword arguments that the tests pass to `_take_next_step`."""
    # The object exposing `.tools` sits a few wrappers deep in the materialized agent.
    executor = agent.agent.bound.bound.steps[1].bound

    tools_by_name = {}
    for tool in executor.tools:
        tools_by_name[tool.name] = tool

    return {
        "name_to_tool_map": tools_by_name,
        "color_mapping": {},
        "inputs": {},
        "intermediate_steps": [],
    }


def test_agent_plan(agent):
agent_executor = agent.agent
agent_action = AgentAction("tool", "tool_input", "tool_log")
Expand All @@ -71,15 +84,7 @@ def test_agent_plan(agent):
assert step.tool_input == "test_output"


def test_agent_take_next_step(agent):
agent_executor = agent.agent.bound.bound.steps[1].bound

run_kwargs = {
"name_to_tool_map": {tool.name: tool for tool in agent_executor.tools},
"color_mapping": {},
"inputs": {},
"intermediate_steps": [],
}
def test_agent_take_next_step(agent, run_kwargs):

# test wrong output
input_data = "Latest advancements in AI in 2024."
Expand All @@ -95,3 +100,14 @@ def test_agent_take_next_step(agent):
assert isinstance(step_result.return_values, dict)
output_result = step_result.return_values.get("output")
assert output_result == {"checked_output": input_data}


def test_output_handler_max_iteration(agent, run_kwargs):
    """Check that repeated steps end in OutputHandlerMaxIterationsExceeded."""
    run_kwargs["inputs"] = "Latest advancements in AI in 2024."
    limit = agent.output_handler.max_iterations

    # One step more than the limit is attempted; the exception is expected
    # to interrupt the loop on that final step.
    with pytest.raises(OutputHandlerMaxIterationsExceeded):
        for iteration in range(limit + 1):
            agent.agent._take_next_step(**run_kwargs)

    # The loop index at the point of failure must be the extra (limit + 1)-th step.
    assert iteration == limit
Loading

0 comments on commit a4c2fdd

Please sign in to comment.