Skip to content

Commit

Permalink
add docstrings for tools package
Browse files Browse the repository at this point in the history
  • Loading branch information
User committed May 29, 2024
1 parent 59b0d73 commit 8bdb83a
Show file tree
Hide file tree
Showing 7 changed files with 231 additions and 7 deletions.
34 changes: 34 additions & 0 deletions motleycrew/tools/autogen_chat_tool.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
""" Module description """
from typing import Optional, Type, Callable, Any

from langchain_core.tools import StructuredTool
Expand All @@ -16,6 +17,14 @@


def get_last_message(chat_result: ChatResult) -> str:
""" Description
Args:
chat_result (ChatResult):
Returns:
str:
"""
for message in reversed(chat_result.chat_history):
if message.get("content") and "TERMINATE" not in message["content"]:
return message["content"]
Expand All @@ -32,6 +41,17 @@ def __init__(
result_extractor: Callable[[ChatResult], Any] = get_last_message,
input_schema: Optional[Type[BaseModel]] = None,
):
""" Description
Args:
name (str):
description (str):
prompt (:obj:`str`, :obj:`BasePromptTemplate`):
initiator (ConversableAgent):
recipient (ConversableAgent):
result_extractor (:obj:`Callable[[ChatResult]`, :obj:`Any`, optional):
input_schema (:obj:`Type[BaseModel]`, optional):
"""
ensure_module_is_installed("autogen")
langchain_tool = create_autogen_chat_tool(
name=name,
Expand All @@ -54,6 +74,20 @@ def create_autogen_chat_tool(
result_extractor: Callable[[ChatResult], Any],
input_schema: Optional[Type[BaseModel]] = None,
):
""" Description
Args:
name (str):
description (str):
prompt (:obj:`str`, :obj:`BasePromptTemplate`):
initiator (ConversableAgent):
recipient (ConversableAgent):
result_extractor (:obj:`Callable[[ChatResult]`, :obj:`Any`, optional):
input_schema (:obj:`Type[BaseModel]`, optional):
Returns:
"""
if not isinstance(prompt, BasePromptTemplate):
prompt = PromptTemplate.from_template(prompt)

Expand Down
59 changes: 58 additions & 1 deletion motleycrew/tools/image/dall_e.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
""" Module description
Attributes:
prompt_template (str):
dall_e_template (str):
"""
from typing import Optional

import os
Expand All @@ -17,6 +23,15 @@


def download_image(url: str, file_path: str) -> Optional[str]:
""" Description
Args:
url (str):
file_path (str):
Returns:
:obj:`str`, None:
"""
response = requests.get(url, stream=True)
if response.status_code == requests.codes.ok:
content_type = response.headers.get("content-type")
Expand Down Expand Up @@ -46,6 +61,16 @@ def __init__(
size: str = "1024x1024",
style: Optional[str] = None,
):
""" Description
Args:
images_directory (:obj:`str`, optional):
refine_prompt_with_llm (:obj:`bool`, optional):
model (:obj:`str`, optional):
quality (:obj:`str`, optional):
size (:obj:`str`, optional):
style (:obj:`str`, optional):
"""
langchain_tool = create_dalle_image_generator_langchain_tool(
images_directory=images_directory,
refine_prompt_with_llm=refine_prompt_with_llm,
Expand All @@ -58,7 +83,11 @@ def __init__(


class DallEToolInput(BaseModel):
"""Input for the Dall-E tool."""
"""Input for the Dall-E tool.
Attributes:
description (str):
"""

description: str = Field(description="image description")

Expand All @@ -83,6 +112,21 @@ def run_dalle_and_save_images(
style: Optional[str] = None,
file_name_length: int = 8,
) -> Optional[list[str]]:
""" Description
Args:
description (str):
images_directory (:obj:`str`, optional):
refine_prompt_with_llm(:obj:`bool`, optional):
model (:obj:`str`, optional):
quality (:obj:`str`, optional):
size (:obj:`str`, optional):
style (:obj:`str`, optional):
file_name_length (:obj:`int`, optional):
Returns:
:obj:`list` of :obj:`str`:
"""

dall_e_prompt = PromptTemplate.from_template(dall_e_template)

Expand Down Expand Up @@ -133,6 +177,19 @@ def create_dalle_image_generator_langchain_tool(
size: str = "1024x1024",
style: Optional[str] = None,
):
""" Description
Args:
images_directory (:obj:`str`, optional):
refine_prompt_with_llm (:obj:`bool`, optional):
model (:obj:`str`, optional):
quality (:obj:`str`, optional):
size (:obj:`str`, optional):
style (:obj:`str`, optional):
Returns:
Tool:
"""
def run_dalle_and_save_images_partial(description: str):
return run_dalle_and_save_images(
description=description,
Expand Down
22 changes: 22 additions & 0 deletions motleycrew/tools/llm_tool.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
""" Module description"""
from typing import Optional, Type

from langchain_core.tools import StructuredTool
Expand All @@ -20,6 +21,15 @@ def __init__(
llm: Optional[BaseLanguageModel] = None,
input_schema: Optional[Type[BaseModel]] = None,
):
""" Description
Args:
name (str):
description (str):
prompt (:obj:`str`, :obj:`BasePromptTemplate`):
llm (:obj:`BaseLanguageModel`, optional):
input_schema (:obj:`Type[BaseModel]`, optional):
"""
langchain_tool = create_llm_langchain_tool(
name=name,
description=description,
Expand All @@ -37,6 +47,18 @@ def create_llm_langchain_tool(
llm: Optional[BaseLanguageModel] = None,
input_schema: Optional[Type[BaseModel]] = None,
):
""" Description
Args:
name (str):
description (str):
prompt (:obj:`str`, :obj:`BasePromptTemplate`):
llm (:obj:`BaseLanguageModel`, optional):
input_schema (:obj:`Type[BaseModel]`, optional):
Returns:
"""
if llm is None:
llm = init_llm(llm_framework=LLMFramework.LANGCHAIN)

Expand Down
15 changes: 15 additions & 0 deletions motleycrew/tools/mermaid_evaluator_tool.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
""" Module description """
# https://nodejs.org/en/download
# npm install -g @mermaid-js/mermaid-cli
import os.path
Expand All @@ -13,6 +14,11 @@

class MermaidEvaluatorTool(MotleyTool):
def __init__(self, format: Optional[str] = "svg"):
""" Description
Args:
format (:obj:`str`, None):
"""
def eval_mermaid_partial(mermaid_code: str):
return eval_mermaid(mermaid_code, format)

Expand All @@ -29,6 +35,15 @@ def eval_mermaid_partial(mermaid_code: str):


def eval_mermaid(mermaid_code: str, format: Optional[str] = "svg") -> io.BytesIO:
""" Description
Args:
mermaid_code (str):
format (:obj:`str`, optional):
Returns:
io.BytesIO:
"""
with tempfile.NamedTemporaryFile(delete=True, mode="w+", suffix=".mmd") as temp_in:
temp_in.write(mermaid_code)
temp_in.flush() # Ensure all data is written to disk
Expand Down
15 changes: 14 additions & 1 deletion motleycrew/tools/python_repl.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
""" Module description """
from langchain.agents import Tool
from langchain_experimental.utilities import PythonREPL
from langchain_core.pydantic_v1 import BaseModel, Field
Expand All @@ -7,18 +8,30 @@

class PythonREPLTool(MotleyTool):
def __init__(self):
""" Description
"""
langchain_tool = create_repl_tool()
super().__init__(langchain_tool)


class REPLToolInput(BaseModel):
"""Input for the REPL tool."""
"""Input for the REPL tool.
Attributes:
command (str):
"""

command: str = Field(description="code to execute")


# You can create the tool to pass to an agent
def create_repl_tool():
""" Description
Returns:
Tool:
"""
return Tool.from_function(
func=PythonREPL().run,
name="python_repl",
Expand Down
25 changes: 24 additions & 1 deletion motleycrew/tools/simple_retriever_tool.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
""" Module description """
import os.path

from langchain_core.pydantic_v1 import BaseModel, Field
Expand All @@ -19,21 +20,43 @@

class SimpleRetrieverTool(MotleyTool):
def __init__(self, DATA_DIR, PERSIST_DIR, return_strings_only: bool = False):
""" Description
Args:
DATA_DIR (str):
PERSIST_DIR (str):
return_strings_only (:obj:`bool`, optional):
"""
tool = make_retriever_langchain_tool(
DATA_DIR, PERSIST_DIR, return_strings_only=return_strings_only
)
super().__init__(tool)


class RetrieverToolInput(BaseModel, arbitrary_types_allowed=True):
"""Input for the Retriever Tool."""
"""Input for the Retriever Tool.
Attributes:
question (Question):
"""

question: Question = Field(
description="The input question for which to retrieve relevant data."
)


def make_retriever_langchain_tool(DATA_DIR, PERSIST_DIR, return_strings_only: bool = False):
""" Description
Args:
DATA_DIR (str):
PERSIST_DIR (str):
return_strings_only (:obj:`bool`, optional):
Returns:
"""
text_embedding_model = "text-embedding-ada-002"
embeddings = OpenAIEmbedding(model=text_embedding_model)

Expand Down
Loading

0 comments on commit 8bdb83a

Please sign in to comment.