From 8bdb83a890db10f8a931e663a02311fa6114839e Mon Sep 17 00:00:00 2001 From: User Date: Wed, 29 May 2024 11:39:58 +0300 Subject: [PATCH] add docstrings for tools package --- motleycrew/tools/autogen_chat_tool.py | 34 +++++++++++ motleycrew/tools/image/dall_e.py | 59 ++++++++++++++++++- motleycrew/tools/llm_tool.py | 22 +++++++ motleycrew/tools/mermaid_evaluator_tool.py | 15 +++++ motleycrew/tools/python_repl.py | 15 ++++- motleycrew/tools/simple_retriever_tool.py | 25 +++++++- motleycrew/tools/tool.py | 68 ++++++++++++++++++++-- 7 files changed, 231 insertions(+), 7 deletions(-) diff --git a/motleycrew/tools/autogen_chat_tool.py b/motleycrew/tools/autogen_chat_tool.py index fd6065db..f3e46f49 100644 --- a/motleycrew/tools/autogen_chat_tool.py +++ b/motleycrew/tools/autogen_chat_tool.py @@ -1,3 +1,4 @@ +""" Module description """ from typing import Optional, Type, Callable, Any from langchain_core.tools import StructuredTool @@ -16,6 +17,14 @@ def get_last_message(chat_result: ChatResult) -> str: + """ Description + + Args: + chat_result (ChatResult): + + Returns: + str: + """ for message in reversed(chat_result.chat_history): if message.get("content") and "TERMINATE" not in message["content"]: return message["content"] @@ -32,6 +41,17 @@ def __init__( result_extractor: Callable[[ChatResult], Any] = get_last_message, input_schema: Optional[Type[BaseModel]] = None, ): + """ Description + + Args: + name (str): + description (str): + prompt (:obj:`str`, :obj:`BasePromptTemplate`): + initiator (ConversableAgent): + recipient (ConversableAgent): + result_extractor (:obj:`Callable[[ChatResult]`, :obj:`Any`, optional): + input_schema (:obj:`Type[BaseModel]`, optional): + """ ensure_module_is_installed("autogen") langchain_tool = create_autogen_chat_tool( name=name, @@ -54,6 +74,20 @@ def create_autogen_chat_tool( result_extractor: Callable[[ChatResult], Any], input_schema: Optional[Type[BaseModel]] = None, ): + """ Description + + Args: + name (str): + description (str): + prompt (:obj:`str`, :obj:`BasePromptTemplate`): + initiator (ConversableAgent): + recipient (ConversableAgent): + result_extractor (:obj:`Callable[[ChatResult]`, :obj:`Any`, optional): + input_schema (:obj:`Type[BaseModel]`, optional): + + Returns: + + """ if not isinstance(prompt, BasePromptTemplate): prompt = PromptTemplate.from_template(prompt) diff --git a/motleycrew/tools/image/dall_e.py b/motleycrew/tools/image/dall_e.py index 2c967908..064785ec 100644 --- a/motleycrew/tools/image/dall_e.py +++ b/motleycrew/tools/image/dall_e.py @@ -1,3 +1,9 @@ +""" Module description + +Attributes: + prompt_template (str): + dall_e_template (str): +""" from typing import Optional import os @@ -17,6 +23,15 @@ def download_image(url: str, file_path: str) -> Optional[str]: + """ Description + + Args: + url (str): + file_path (str): + + Returns: + :obj:`str`, None: + """ response = requests.get(url, stream=True) if response.status_code == requests.codes.ok: content_type = response.headers.get("content-type") @@ -46,6 +61,16 @@ def __init__( size: str = "1024x1024", style: Optional[str] = None, ): + """ Description + + Args: + images_directory (:obj:`str`, optional): + refine_prompt_with_llm (:obj:`bool`, optional): + model (:obj:`str`, optional): + quality (:obj:`str`, optional): + size (:obj:`str`, optional): + style (:obj:`str`, optional): + """ langchain_tool = create_dalle_image_generator_langchain_tool( images_directory=images_directory, refine_prompt_with_llm=refine_prompt_with_llm, @@ -58,7 +83,11 @@ def __init__( class DallEToolInput(BaseModel): - """Input for the Dall-E tool.""" + """Input for the Dall-E tool. + + Attributes: + description (str): + """ description: str = Field(description="image description") @@ -83,6 +112,21 @@ def run_dalle_and_save_images( style: Optional[str] = None, file_name_length: int = 8, ) -> Optional[list[str]]: + """ Description + + Args: + description (str): + images_directory (:obj:`str`, optional): + refine_prompt_with_llm(:obj:`bool`, optional): + model (:obj:`str`, optional): + quality (:obj:`str`, optional): + size (:obj:`str`, optional): + style (:obj:`str`, optional): + file_name_length (:obj:`int`, optional): + + Returns: + :obj:`list` of :obj:`str`: + """ dall_e_prompt = PromptTemplate.from_template(dall_e_template) @@ -133,6 +177,19 @@ def create_dalle_image_generator_langchain_tool( size: str = "1024x1024", style: Optional[str] = None, ): + """ Description + + Args: + images_directory (:obj:`str`, optional): + refine_prompt_with_llm (:obj:`bool`, optional): + model (:obj:`str`, optional): + quality (:obj:`str`, optional): + size (:obj:`str`, optional): + style (:obj:`str`, optional): + + Returns: + Tool: + """ def run_dalle_and_save_images_partial(description: str): return run_dalle_and_save_images( description=description, diff --git a/motleycrew/tools/llm_tool.py b/motleycrew/tools/llm_tool.py index adae279b..cabfc40e 100644 --- a/motleycrew/tools/llm_tool.py +++ b/motleycrew/tools/llm_tool.py @@ -1,3 +1,4 @@ +""" Module description""" from typing import Optional, Type from langchain_core.tools import StructuredTool @@ -20,6 +21,15 @@ def __init__( llm: Optional[BaseLanguageModel] = None, input_schema: Optional[Type[BaseModel]] = None, ): + """ Description + + Args: + name (str): + description (str): + prompt (:obj:`str`, :obj:`BasePromptTemplate`): + llm (:obj:`BaseLanguageModel`, optional): + input_schema (:obj:`Type[BaseModel]`, optional): + """ langchain_tool = create_llm_langchain_tool( name=name, description=description, @@ -37,6 +47,18 @@ def create_llm_langchain_tool( llm: Optional[BaseLanguageModel] = None, input_schema: Optional[Type[BaseModel]] = None, ): + """ Description + + Args: + name (str): + description (str): + prompt (:obj:`str`, :obj:`BasePromptTemplate`): + llm (:obj:`BaseLanguageModel`, optional): + input_schema (:obj:`Type[BaseModel]`, optional): + + Returns: + + """ if llm is None: llm = init_llm(llm_framework=LLMFramework.LANGCHAIN) diff --git a/motleycrew/tools/mermaid_evaluator_tool.py b/motleycrew/tools/mermaid_evaluator_tool.py index 9808899a..b25f70fa 100644 --- a/motleycrew/tools/mermaid_evaluator_tool.py +++ b/motleycrew/tools/mermaid_evaluator_tool.py @@ -1,3 +1,4 @@ +""" Module description """ # https://nodejs.org/en/download # npm install -g @mermaid-js/mermaid-cli import os.path @@ -13,6 +14,11 @@ class MermaidEvaluatorTool(MotleyTool): def __init__(self, format: Optional[str] = "svg"): + """ Description + + Args: + format (:obj:`str`, None): + """ def eval_mermaid_partial(mermaid_code: str): return eval_mermaid(mermaid_code, format) @@ -29,6 +35,15 @@ def eval_mermaid_partial(mermaid_code: str): def eval_mermaid(mermaid_code: str, format: Optional[str] = "svg") -> io.BytesIO: + """ Description + + Args: + mermaid_code (str): + format (:obj:`str`, optional): + + Returns: + io.BytesIO: + """ with tempfile.NamedTemporaryFile(delete=True, mode="w+", suffix=".mmd") as temp_in: temp_in.write(mermaid_code) temp_in.flush() # Ensure all data is written to disk diff --git a/motleycrew/tools/python_repl.py b/motleycrew/tools/python_repl.py index 55fffc8d..c03c1aac 100644 --- a/motleycrew/tools/python_repl.py +++ b/motleycrew/tools/python_repl.py @@ -1,3 +1,4 @@ +""" Module description """ from langchain.agents import Tool from langchain_experimental.utilities import PythonREPL from langchain_core.pydantic_v1 import BaseModel, Field @@ -7,18 +8,30 @@ class PythonREPLTool(MotleyTool): def __init__(self): + """ Description + + """ langchain_tool = create_repl_tool() super().__init__(langchain_tool) class REPLToolInput(BaseModel): - """Input for the REPL tool.""" + """Input for the REPL tool. + + Attributes: + command (str): + """ command: str = Field(description="code to execute") # You can create the tool to pass to an agent def create_repl_tool(): + """ Description + + Returns: + Tool: + """ return Tool.from_function( func=PythonREPL().run, name="python_repl", diff --git a/motleycrew/tools/simple_retriever_tool.py b/motleycrew/tools/simple_retriever_tool.py index 86db8312..03976294 100644 --- a/motleycrew/tools/simple_retriever_tool.py +++ b/motleycrew/tools/simple_retriever_tool.py @@ -1,3 +1,4 @@ +""" Module description """ import os.path from langchain_core.pydantic_v1 import BaseModel, Field @@ -19,6 +20,13 @@ class SimpleRetrieverTool(MotleyTool): def __init__(self, DATA_DIR, PERSIST_DIR, return_strings_only: bool = False): + """ Description + + Args: + DATA_DIR (str): + PERSIST_DIR (str): + return_strings_only (:obj:`bool`, optional): + """ tool = make_retriever_langchain_tool( DATA_DIR, PERSIST_DIR, return_strings_only=return_strings_only ) @@ -26,7 +34,12 @@ def __init__(self, DATA_DIR, PERSIST_DIR, return_strings_only: bool = False): class RetrieverToolInput(BaseModel, arbitrary_types_allowed=True): - """Input for the Retriever Tool.""" + """Input for the Retriever Tool. + + Attributes: + question (Question): + + """ question: Question = Field( description="The input question for which to retrieve relevant data." @@ -34,6 +47,16 @@ class RetrieverToolInput(BaseModel, arbitrary_types_allowed=True): def make_retriever_langchain_tool(DATA_DIR, PERSIST_DIR, return_strings_only: bool = False): + """ Description + + Args: + DATA_DIR (str): + PERSIST_DIR (str): + return_strings_only (:obj:`bool`, optional): + + Returns: + + """ text_embedding_model = "text-embedding-ada-002" embeddings = OpenAIEmbedding(model=text_embedding_model) diff --git a/motleycrew/tools/tool.py b/motleycrew/tools/tool.py index 812532e6..658eecfb 100644 --- a/motleycrew/tools/tool.py +++ b/motleycrew/tools/tool.py @@ -1,3 +1,4 @@ +""" Module description """ from typing import Union, Annotated from langchain.tools import BaseTool @@ -14,6 +15,15 @@ def normalize_input(args, kwargs): + """ Description + + Args: + args (Sequence): + kwargs (Map): + + Returns: + Any: + """ if "tool_input" in kwargs: return kwargs["tool_input"] else: @@ -21,12 +31,14 @@ def normalize_input(args, kwargs): class MotleyTool(Runnable): - """ - Base tool class compatible with MotleyAgents. - It is a wrapper for LangChain's BaseTool, containing all necessary adapters and converters. - """ def __init__(self, tool: BaseTool): + """ Base tool class compatible with MotleyAgents. + It is a wrapper for LangChain's BaseTool, containing all necessary adapters and converters. + + Args: + tool (BaseTool): + """ self.tool = tool @property @@ -35,14 +47,39 @@ def name(self): return self.tool.name def invoke(self, *args, **kwargs): + """ Description + + Args: + *args: + **kwargs: + + Returns: + Any: + """ return self.tool.invoke(*args, **kwargs) @staticmethod def from_langchain_tool(langchain_tool: BaseTool) -> "MotleyTool": + """ Description + + Args: + langchain_tool (BaseTool): + + Returns: + MotleyTool: + """ return MotleyTool(tool=langchain_tool) @staticmethod def from_llama_index_tool(llama_index_tool: LlamaIndex__BaseTool) -> "MotleyTool": + """ Description + + Args: + llama_index_tool (LlamaIndex__BaseTool): + + Returns: + MotleyTool: + """ ensure_module_is_installed("llama_index") langchain_tool = llama_index_tool.to_langchain_tool() return MotleyTool.from_langchain_tool(langchain_tool=langchain_tool) @@ -51,6 +88,14 @@ def from_llama_index_tool(llama_index_tool: LlamaIndex__BaseTool) -> "MotleyTool def from_supported_tool( tool: Union["MotleyTool", BaseTool, LlamaIndex__BaseTool, MotleyAgentAbstractParent] ): + """ Description + + Args: + tool (:obj:`MotleyTool`, :obj:`BaseTool`, :obj:`LlamaIndex__BaseTool`, :obj:`MotleyAgentAbstractParent`): + + Returns: + + """ if isinstance(tool, MotleyTool): return tool elif isinstance(tool, BaseTool): @@ -65,9 +110,19 @@ def from_supported_tool( ) def to_langchain_tool(self) -> BaseTool: + """ Description + + Returns: + BaseTool: + """ return self.tool def to_llama_index_tool(self) -> LlamaIndex__BaseTool: + """ Description + + Returns: + LlamaIndex__BaseTool: + """ ensure_module_is_installed("llama_index") llama_index_tool = LlamaIndex__FunctionTool.from_defaults( fn=self.tool._run, @@ -78,6 +133,11 @@ def to_llama_index_tool(self) -> LlamaIndex__BaseTool: return llama_index_tool def to_autogen_tool(self): + """ Description + + Returns: + Callable: + """ fields = list(self.tool.args_schema.__fields__.values()) if len(fields) != 1: raise Exception("Multiple input fields are not supported in to_autogen_tool")