diff --git a/examples/research_agent/question_answerer.py b/examples/research_agent/question_answerer.py
new file mode 100644
index 00000000..e69de29b
diff --git a/motleycrew/agent/coordinator.py b/motleycrew/agent/coordinator.py
deleted file mode 100644
index 237325b7..00000000
--- a/motleycrew/agent/coordinator.py
+++ /dev/null
@@ -1,10 +0,0 @@
-from abc import ABC, abstractmethod
-from typing import Sequence
-
-from motleycrew.task import Task
-
-
-class TaskCoordinator(ABC):
-    @abstractmethod
-    def order(self, tasks: Sequence[Task]) -> Sequence[Task]:
-        pass
diff --git a/motleycrew/tool/__init__.py b/motleycrew/tool/__init__.py
index 84596777..a9e7456e 100644
--- a/motleycrew/tool/__init__.py
+++ b/motleycrew/tool/__init__.py
@@ -1 +1,3 @@
 from .tool import MotleyTool
+from .llm_tool import LLMTool
+from .image_generation import DallEImageGeneratorTool
diff --git a/motleycrew/tool/llm_tool.py b/motleycrew/tool/llm_tool.py
index 1b024abd..d2567303 100644
--- a/motleycrew/tool/llm_tool.py
+++ b/motleycrew/tool/llm_tool.py
@@ -1,10 +1,10 @@
-from typing import Optional
+from typing import Optional, Type
 
-from langchain_core.tools import Tool
+from langchain_core.tools import StructuredTool
 from langchain_core.prompts import PromptTemplate
 from langchain_core.prompts.base import BasePromptTemplate
 from langchain_core.language_models import BaseLanguageModel
-from langchain_core.pydantic_v1 import BaseModel, Field
+from langchain_core.pydantic_v1 import BaseModel, Field, create_model
 
 from motleycrew.tool import MotleyTool
 from motleycrew.common import LLMFramework
@@ -18,10 +18,14 @@ def __init__(
         name: str,
         description: str,
         prompt: str | BasePromptTemplate,
         llm: Optional[BaseLanguageModel] = None,
-        input_description: Optional[str] = "Input for the tool.",
+        input_schema: Optional[Type[BaseModel]] = None,
     ):
         langchain_tool = create_llm_langchain_tool(
-            name=name, description=description, prompt=prompt, llm=llm, input_description=input_description
+            name=name,
+            description=description,
+            prompt=prompt,
+            llm=llm,
+            input_schema=input_schema,
         )
         super().__init__(langchain_tool)
@@ -30,8 +34,8 @@ def create_llm_langchain_tool(
     name: str,
     description: str,
     prompt: str | BasePromptTemplate,
-    llm: Optional[BaseLanguageModel],
-    input_description: Optional[str],
+    llm: Optional[BaseLanguageModel] = None,
+    input_schema: Optional[Type[BaseModel]] = None,
 ):
     if llm is None:
         llm = init_llm(llm_framework=LLMFramework.LANGCHAIN)
@@ -39,22 +43,22 @@
     if not isinstance(prompt, BasePromptTemplate):
         prompt = PromptTemplate.from_template(prompt)
 
-    assert len(prompt.input_variables) == 1, "Prompt must contain exactly one input variable"
-    input_var = prompt.input_variables[0]
+    if input_schema is None:
+        fields = {
+            var: (str, Field(description=f"Input {var} for the tool."))
+            for var in prompt.input_variables
+        }
 
-    class LLMToolInput(BaseModel):
-        """Input for the tool."""
+        # Create the LLMToolInput class dynamically
+        input_schema = create_model("LLMToolInput", **fields)
 
-        # TODO: how hard is it to get that name from prompt.input_variables?
-        input: str = Field(description=input_description)
-
-    def call_llm(input: str) -> str:
+    def call_llm(**kwargs) -> str:
         chain = prompt | llm
-        return chain.invoke({input_var: input})
+        return chain.invoke(kwargs)
 
-    return Tool.from_function(
+    return StructuredTool.from_function(
         func=call_llm,
         name=name,
         description=description,
-        args_schema=LLMToolInput,
+        args_schema=input_schema,
    )
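
Usage note for reviewers: with this change, `LLMTool` drops the single-input restriction and accepts prompts with any number of input variables, either generating a pydantic schema from them or taking an explicit `input_schema`. A minimal sketch of both paths (the tool name, prompt text, and field descriptions below are illustrative, not part of this diff):

```python
from langchain_core.pydantic_v1 import BaseModel, Field

from motleycrew.tool import LLMTool

# Path 1: no input_schema -- a pydantic model is generated from the
# prompt's input variables ({style} and {text} here), so the underlying
# StructuredTool accepts both fields.
summarizer = LLMTool(
    name="style_summarizer",
    description="Summarizes a text in a given style.",
    prompt="Summarize the following text in a {style} style:\n\n{text}",
)

# Path 2: an explicit input_schema, for custom field descriptions or types.
class SummarizerInput(BaseModel):
    text: str = Field(description="Text to summarize.")
    style: str = Field(description="Style to write the summary in.")

summarizer_with_schema = LLMTool(
    name="style_summarizer",
    description="Summarizes a text in a given style.",
    prompt="Summarize the following text in a {style} style:\n\n{text}",
    input_schema=SummarizerInput,
)
```

Both paths route through `create_llm_langchain_tool`, which now builds a `StructuredTool` whose `call_llm(**kwargs)` forwards all fields to the prompt, so the old "exactly one input variable" assertion is no longer needed. How the tool is then invoked (e.g. via the wrapped langchain tool) depends on the `MotleyTool` interface, which this diff doesn't touch.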