Skip to content

Commit

Permalink
Question prioritizer first cut
Browse files Browse the repository at this point in the history
  • Loading branch information
ZmeiGorynych committed Apr 26, 2024
1 parent 6d33b69 commit 330e692
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 28 deletions.
Empty file.
10 changes: 0 additions & 10 deletions motleycrew/agent/coordinator.py

This file was deleted.

2 changes: 2 additions & 0 deletions motleycrew/tool/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
from .tool import MotleyTool
from .llm_tool import LLMTool
from .image_generation import DallEImageGeneratorTool
40 changes: 22 additions & 18 deletions motleycrew/tool/llm_tool.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from typing import Optional
from typing import Optional, Type

from langchain_core.tools import Tool
from langchain_core.tools import StructuredTool
from langchain_core.prompts import PromptTemplate
from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.language_models import BaseLanguageModel
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.pydantic_v1 import BaseModel, Field, create_model

from motleycrew.tool import MotleyTool
from motleycrew.common import LLMFramework
Expand All @@ -18,10 +18,14 @@ def __init__(
description: str,
prompt: str | BasePromptTemplate,
llm: Optional[BaseLanguageModel] = None,
input_description: Optional[str] = "Input for the tool.",
input_schema: Optional[Type[BaseModel]] = None,
):
langchain_tool = create_llm_langchain_tool(
name=name, description=description, prompt=prompt, llm=llm, input_description=input_description
name=name,
description=description,
prompt=prompt,
llm=llm,
input_schema=input_schema,
)
super().__init__(langchain_tool)

Expand All @@ -30,31 +34,31 @@ def create_llm_langchain_tool(
name: str,
description: str,
prompt: str | BasePromptTemplate,
llm: Optional[BaseLanguageModel],
input_description: Optional[str],
llm: Optional[BaseLanguageModel] = None,
input_schema: Optional[Type[BaseModel]] = None,
):
if llm is None:
llm = init_llm(llm_framework=LLMFramework.LANGCHAIN)

if not isinstance(prompt, BasePromptTemplate):
prompt = PromptTemplate.from_template(prompt)

assert len(prompt.input_variables) == 1, "Prompt must contain exactly one input variable"
input_var = prompt.input_variables[0]
if input_schema is None:
fields = {
var: (str, Field(description=f"Input {var} for the tool."))
for var in prompt.input_variables
}

class LLMToolInput(BaseModel):
"""Input for the tool."""
# Create the LLMToolInput class dynamically
input_schema = create_model("LLMToolInput", **fields)

# TODO: how hard is it to get that name from prompt.input_variables?
input: str = Field(description=input_description)

def call_llm(input: str) -> str:
def call_llm(**kwargs) -> str:
chain = prompt | llm
return chain.invoke({input_var: input})
return chain.invoke(kwargs)

return Tool.from_function(
return StructuredTool.from_function(
func=call_llm,
name=name,
description=description,
args_schema=LLMToolInput,
args_schema=input_schema,
)

0 comments on commit 330e692

Please sign in to comment.