Skip to content

Commit

Permalink
Support agent app & async tools (#89)
Browse files Browse the repository at this point in the history
* Support agent demo app

* Fixes

* Fix tests

* Downgrade Kuzu
  • Loading branch information
whimo authored Oct 14, 2024
1 parent 2323835 commit 3f357a1
Show file tree
Hide file tree
Showing 23 changed files with 2,088 additions and 922 deletions.
261 changes: 57 additions & 204 deletions examples/Multi-step research agent.ipynb

Large diffs are not rendered by default.

23 changes: 22 additions & 1 deletion motleycrew/agents/langchain/langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,12 +116,22 @@ def materialize(self):
object.__setattr__(
self._agent.agent, "plan", self.agent_plan_decorator(self._agent.agent.plan)
)
if hasattr(self._agent.agent, "aplan"):
object.__setattr__(
self._agent.agent, "aplan", self.agent_aplan_decorator(self._agent.agent.aplan)
)

object.__setattr__(
self._agent,
"_take_next_step",
self.take_next_step_decorator(self._agent._take_next_step),
)
if hasattr(self._agent, "_atake_next_step"):
object.__setattr__(
self._agent,
"_atake_next_step",
self.take_next_step_decorator(self._agent._atake_next_step),
)

for tool in self.agent.tools:
if tool.return_direct:
Expand All @@ -130,12 +140,23 @@ def materialize(self):
"_run",
self._run_tool_direct_decorator(tool._run),
)

object.__setattr__(
tool,
"run",
self.run_tool_direct_decorator(tool.run),
)
if hasattr(tool, "_arun"):
object.__setattr__(
tool,
"_arun",
self._run_tool_direct_decorator(tool._arun),
)
if hasattr(tool, "arun"):
object.__setattr__(
tool,
"arun",
self.run_tool_direct_decorator(tool.arun),
)

if self.get_session_history_callable:
logger.info("Wrapping agent in RunnableWithMessageHistory")
Expand Down
3 changes: 2 additions & 1 deletion motleycrew/agents/llama_index/llama_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def __init__(
)

self.direct_output = None
self.aux_prompts = AuxPrompts()

def _propagate_error_step(self, task_id: str, message: str):
error_step = TaskStep(
Expand Down Expand Up @@ -133,7 +134,7 @@ def wrapper(
cur_step_output.is_last = False
self._propagate_error_step(
task_id=cur_step_output.task_step.task_id,
message=AuxPrompts.get_direct_output_error_message(
message=self.aux_prompts.get_direct_output_error_message(
output_handlers=output_handlers
),
)
Expand Down
135 changes: 118 additions & 17 deletions motleycrew/agents/mixins.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
from typing import Any, Optional, Callable, Union, Dict, List, Tuple
import asyncio
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

from langchain_core.agents import AgentFinish, AgentAction
from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.callbacks import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain_core.messages import AIMessage
from langchain_core.runnables import RunnableConfig
from langchain_core.tools import BaseTool, StructuredTool
from motleycrew.tools import MotleyTool, DirectOutput

from motleycrew.common import AuxPrompts
from motleycrew.tools import DirectOutput, MotleyTool


class LangchainOutputHandlingAgentMixin:
Expand All @@ -15,6 +20,7 @@ class LangchainOutputHandlingAgentMixin:
_agent_error_tool: Optional[BaseTool] = None
get_output_handlers: Callable[[], List[MotleyTool]] = None
force_output_handler: bool = False
aux_prompts: AuxPrompts = AuxPrompts()

def _create_agent_error_tool(self) -> BaseTool:
"""Create a tool that will force the agent to retry if it attempts to return the output
Expand Down Expand Up @@ -63,7 +69,7 @@ def wrapper(
if self._is_error_action(action):
# Add the interaction telling the LLM that it errored
additional_notes.append(("ai", action.tool_input["message"]))
additional_notes.append(("user", action_output))
additional_notes.append(("system", action_output))
to_remove_steps.append(intermediate_step)

for to_remove_step in to_remove_steps:
Expand All @@ -82,7 +88,9 @@ def wrapper(
# Attempted to return output directly, blocking
return self._create_error_action(
message=step.log,
error_message=AuxPrompts.get_direct_output_error_message(output_handlers),
error_message=self.aux_prompts.get_direct_output_error_message(
output_handlers
),
)
try:
step = list(step)
Expand All @@ -98,7 +106,67 @@ def wrapper(
# Attempted to call multiple output handlers or included other tool calls, blocking
return self._create_error_action(
message=step.log,
error_message=AuxPrompts.get_ambiguous_output_handler_call_error_message(
error_message=self.aux_prompts.get_ambiguous_output_handler_call_error_message(
current_output_handler=action.tool, output_handlers=output_handlers
),
)
return step

return wrapper

def agent_aplan_decorator(self, func: Callable):
"""Decorator for Agent.aplan() method that intercepts AgentFinish events"""

output_handlers = self.get_output_handlers()
output_handler_names = set(handler.name for handler in output_handlers)

async def wrapper(
intermediate_steps: List[Tuple[AgentAction, str]],
callbacks: "Callbacks" = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
additional_notes = []

to_remove_steps = []
for intermediate_step in intermediate_steps:
action, action_output = intermediate_step
if self._is_error_action(action):
additional_notes.append(("ai", action.tool_input["message"]))
additional_notes.append(("system", action_output))
to_remove_steps.append(intermediate_step)

for to_remove_step in to_remove_steps:
intermediate_steps.remove(to_remove_step)

if additional_notes:
kwargs["additional_notes"] = additional_notes

step = await func(intermediate_steps, callbacks, **kwargs)

if isinstance(step, AgentAction):
step = [step]

if output_handlers:
if isinstance(step, AgentFinish) and self.force_output_handler:
return self._create_error_action(
message=step.log,
error_message=self.aux_prompts.get_direct_output_error_message(
output_handlers
),
)
try:
step = list(step)
except TypeError:
return step

if len(step) <= 1:
return step

for action in step:
if action.tool in output_handler_names:
return self._create_error_action(
message=step.log,
error_message=self.aux_prompts.get_ambiguous_output_handler_call_error_message(
current_output_handler=action.tool, output_handlers=output_handlers
),
)
Expand All @@ -108,17 +176,37 @@ def wrapper(

def take_next_step_decorator(self, func: Callable):
"""
Decorator for ``AgentExecutor._take_next_step()`` method that catches DirectOutput exceptions.
Decorator for ``AgentExecutor._take_next_step()`` and ``AgentExecutor._atake_next_step()`` methods
that catches DirectOutput exceptions.
"""

def wrapper(
async def async_wrapper(
name_to_tool_map: Dict[str, BaseTool],
color_mapping: Dict[str, str],
inputs: Dict[str, str],
intermediate_steps: List[Tuple[AgentAction, str]],
run_manager: Optional[CallbackManagerForChainRun] = None,
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Union[AgentFinish, List[Tuple[AgentAction, str]]]:
try:
step = await func(
name_to_tool_map, color_mapping, inputs, intermediate_steps, run_manager
)
except DirectOutput as direct_ex:
message = str(direct_ex.output)
return AgentFinish(
return_values={"output": direct_ex.output},
messages=[AIMessage(content=message)],
log=message,
)
return step

def sync_wrapper(
name_to_tool_map: Dict[str, BaseTool],
color_mapping: Dict[str, str],
inputs: Dict[str, str],
intermediate_steps: List[Tuple[AgentAction, str]],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Union[AgentFinish, List[Tuple[AgentAction, str]]]:
try:
step = func(
name_to_tool_map, color_mapping, inputs, intermediate_steps, run_manager
Expand All @@ -132,27 +220,40 @@ def wrapper(
)
return step

return wrapper
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper

def _run_tool_direct_decorator(self, func: Callable):
"""Decorator of the tool's _run method, for intercepting a DirectOutput exception"""
"""Decorator of the tool's _run and _arun methods, for intercepting a DirectOutput exception"""

async def async_wrapper(*args, config: RunnableConfig, **kwargs):
try:
result = await func(*args, **kwargs, config=config)
except DirectOutput as direct_exc:
return direct_exc
return result

def wrapper(*args, config: RunnableConfig, **kwargs):
def sync_wrapper(*args, config: RunnableConfig, **kwargs):
try:
result = func(*args, **kwargs, config=config)
except DirectOutput as direct_exc:
return direct_exc
return result

return wrapper
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper

def run_tool_direct_decorator(self, func: Callable):
"""Decorator of the tool's run method, for intercepting a DirectOutput exception"""
"""Decorator of the tool's run and arun methods, for intercepting a DirectOutput exception"""

def wrapper(*args, **kwargs):
async def async_wrapper(*args, **kwargs):
result = await func(*args, **kwargs)
if isinstance(result, DirectOutput):
raise result
return result

def sync_wrapper(*args, **kwargs):
result = func(*args, **kwargs)
if isinstance(result, DirectOutput):
raise result
return result

return wrapper
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
22 changes: 22 additions & 0 deletions motleycrew/applications/customer_support/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Customer support agent demo

This is a demo of a customer support app built using motleycrew and Ray.

It includes sample data for populating the issue tree.


## Installation and usage
We suggest you set up a virtualenv for managing the environment.

```
git clone https://github.com/ShoggothAI/motleycrew.git
cd motleycrew
pip install -r requirements.txt
python -m motleycrew.applications.customer_support.issue_tree # populate the issue tree
ray start --head
python -m motleycrew.applications.customer_support.ray_serve_app
```

Navigate to http://127.0.0.1:8000/ and have fun!
Also, check out the Ray dashboard for the app logs etc.
67 changes: 67 additions & 0 deletions motleycrew/applications/customer_support/communication.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from abc import ABC, abstractmethod
import asyncio


class CommunicationInterface(ABC):
@abstractmethod
async def send_message_to_customer(self, message: str) -> str:
"""
Send a message to the customer and return their response.
Args:
message (str): The message to send to the customer.
Returns:
str: The customer's response.
"""
pass

@abstractmethod
def escalate_to_human_agent(self) -> None:
"""
Escalate the current issue to a human agent.
"""
pass

@abstractmethod
def resolve_issue(self, resolution: str) -> str:
"""
Resolve the current issue.
Args:
resolution (str): The resolution to the issue.
Returns:
str: The resolution to the issue.
"""
pass


class DummyCommunicationInterface(CommunicationInterface):
async def send_message_to_customer(self, message: str) -> str:
print(f"Message sent to customer: {message}")
return await asyncio.to_thread(input, "Enter customer's response: ")

def escalate_to_human_agent(self) -> None:
print("Issue escalated to human agent.")

def resolve_issue(self, resolution: str) -> str:
print(f"Proposed resolution: {resolution}")
confirmation = input("Is the issue resolved? (y/n): ")
if confirmation.lower().startswith("y"):
return "Issue resolved"
else:
self.escalate_to_human_agent()


# Placeholder for future implementation
class RealCommunicationInterface(CommunicationInterface):
async def send_message_to_customer(self, message: str) -> str:
# TODO: Implement real asynchronous communication with the customer
# This could involve integrating with a chat system, email, or other communication channels
pass

def escalate_to_human_agent(self) -> None:
# TODO: Implement real escalation to a human agent
# This could involve creating a ticket in a support system or notifying a human agent directly
pass
Loading

0 comments on commit 3f357a1

Please sign in to comment.