Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AutoGen integration #23

Merged
merged 6 commits into from
May 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
756 changes: 756 additions & 0 deletions examples/Using AutoGen chats with motleycrew.ipynb

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions examples/math_crewai.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from dotenv import load_dotenv

from motleycrew import MotleyCrew, Task
from motleycrew import MotleyCrew
from motleycrew.agents.crewai import CrewAIMotleyAgent
from motleycrew.tools.python_repl import create_repl_tool
from motleycrew.tools import PythonREPLTool
from motleycrew.common.utils import configure_logging


def main():
"""Main function of running the example."""
repl_tool = create_repl_tool()
repl_tool = PythonREPLTool()

# Define your agents with roles and goals
solver1 = CrewAIMotleyAgent(
Expand Down
10 changes: 5 additions & 5 deletions motleycrew/common/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ class Defaults:

DEFAULT_GRAPH_STORE_TYPE = GraphStoreType.KUZU


defaults_module_install_commands = {
"crewai": "pip install crewai",
"llama_index": "pip install llama-index"
}
MODULE_INSTALL_COMMANDS = {
"crewai": "pip install crewai",
"llama_index": "pip install llama-index",
"autogen": "pip install autogen",
}
4 changes: 2 additions & 2 deletions motleycrew/common/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from motleycrew.common.defaults import defaults_module_install_commands
from motleycrew.common import Defaults


class LLMFamilyNotSupported(Exception):
Expand Down Expand Up @@ -55,7 +55,7 @@ class ModuleNotInstalledException(Exception):

def __init__(self, module_name: str, install_command: str = None):
self.module_name = module_name
self.install_command = install_command or defaults_module_install_commands.get(
self.install_command = install_command or Defaults.MODULE_INSTALL_COMMANDS.get(
module_name, None
)

Expand Down
6 changes: 5 additions & 1 deletion motleycrew/tools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
from .tool import MotleyTool
from .llm_tool import LLMTool

from .autogen_chat_tool import AutoGenChatTool
from .image_generation import DallEImageGeneratorTool
from .llm_tool import LLMTool
from .mermaid_evaluator_tool import MermaidEvaluatorTool
from .python_repl import PythonREPLTool
79 changes: 79 additions & 0 deletions motleycrew/tools/autogen_chat_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
from typing import Optional, Type, Callable, Any

from langchain_core.tools import StructuredTool
from langchain_core.prompts import PromptTemplate
from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field, create_model

try:
from autogen import ConversableAgent, ChatResult
except ImportError:
ConversableAgent = None
ChatResult = None

from motleycrew.tools import MotleyTool
from motleycrew.common.utils import ensure_module_is_installed


def get_last_message(chat_result: ChatResult) -> str:
for message in reversed(chat_result.chat_history):
if message.get("content") and "TERMINATE" not in message["content"]:
return message["content"]


class AutoGenChatTool(MotleyTool):
def __init__(
self,
name: str,
description: str,
prompt: str | BasePromptTemplate,
initiator: ConversableAgent,
recipient: ConversableAgent,
result_extractor: Callable[[ChatResult], Any] = get_last_message,
input_schema: Optional[Type[BaseModel]] = None,
):
ensure_module_is_installed("autogen")
langchain_tool = create_autogen_chat_tool(
name=name,
description=description,
prompt=prompt,
initiator=initiator,
recipient=recipient,
result_extractor=result_extractor,
input_schema=input_schema,
)
super().__init__(langchain_tool)


def create_autogen_chat_tool(
name: str,
description: str,
prompt: str | BasePromptTemplate,
initiator: ConversableAgent,
recipient: ConversableAgent,
result_extractor: Callable[[ChatResult], Any],
input_schema: Optional[Type[BaseModel]] = None,
):
if not isinstance(prompt, BasePromptTemplate):
prompt = PromptTemplate.from_template(prompt)

if input_schema is None:
fields = {
var: (str, Field(description=f"Input {var} for the tool."))
for var in prompt.input_variables
}

# Create the AutoGenChatToolInput class dynamically
input_schema = create_model("AutoGenChatToolInput", **fields)

def run_autogen_chat(**kwargs) -> Any:
message = prompt.format(**kwargs)
chat_result = initiator.initiate_chat(recipient, message=message)
return result_extractor(chat_result)

return StructuredTool.from_function(
func=run_autogen_chat,
name=name,
description=description,
args_schema=input_schema,
)
20 changes: 12 additions & 8 deletions motleycrew/tools/python_repl.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@
from .tool import MotleyTool


class PythonREPLTool(MotleyTool):
def __init__(self):
langchain_tool = create_repl_tool()
super().__init__(langchain_tool)


class REPLToolInput(BaseModel):
"""Input for the REPL tool."""

Expand All @@ -13,12 +19,10 @@ class REPLToolInput(BaseModel):

# You can create the tool to pass to an agent
def create_repl_tool():
return MotleyTool.from_langchain_tool(
Tool.from_function(
func=PythonREPL().run,
name="python_repl",
description="A Python shell. Use this to execute python commands. Input should be a valid python command. "
"MAKE SURE TO PRINT OUT THE RESULTS YOU CARE ABOUT USING `print(...)`.",
args_schema=REPLToolInput,
)
return Tool.from_function(
func=PythonREPL().run,
name="python_repl",
description="A Python shell. Use this to execute python commands. Input should be a valid python command. "
"MAKE SURE TO PRINT OUT THE RESULTS YOU CARE ABOUT USING `print(...)`.",
args_schema=REPLToolInput,
)
15 changes: 14 additions & 1 deletion motleycrew/tools/tool.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Union
from typing import Union, Annotated

from langchain.tools import BaseTool
from langchain_core.runnables import Runnable
Expand Down Expand Up @@ -76,3 +76,16 @@ def to_llama_index_tool(self) -> LlamaIndex__BaseTool:
fn_schema=self.tool.args_schema,
)
return llama_index_tool

def to_autogen_tool(self):
fields = list(self.tool.args_schema.__fields__.values())
if len(fields) != 1:
raise Exception("Multiple input fields are not supported in to_autogen_tool")

field_name = fields[0].name
field_type = fields[0].annotation

def autogen_tool_fn(input: field_type) -> str:
return self.invoke({field_name: input})

return autogen_tool_fn
Loading
Loading