diff --git a/docs/source/ar/agents.md b/docs/source/ar/agents.md
index 92b2a4715f6..1213b350086 100644
--- a/docs/source/ar/agents.md
+++ b/docs/source/ar/agents.md
@@ -464,7 +464,7 @@ image = image_generator(prompt=improved_prompt)
قبل إنشاء الصورة أخيرًا:
-
+
> [!WARNING]
> تتطلب gradio-tools إدخالات وإخراجات *نصية* حتى عند العمل مع طرائق مختلفة مثل كائنات الصور والصوت. الإدخالات والإخراجات الصورية والصوتية غير متوافقة حاليًا.
diff --git a/docs/source/en/agents_advanced.md b/docs/source/en/agents_advanced.md
index ddcc619b4f9..e80e402d737 100644
--- a/docs/source/en/agents_advanced.md
+++ b/docs/source/en/agents_advanced.md
@@ -123,52 +123,70 @@ from transformers import load_tool, CodeAgent
model_download_tool = load_tool("m-ric/hf-model-downloads")
```
-### Use gradio-tools
+### Import a Space as a tool 🚀
-[gradio-tools](https://github.com/freddyaboulton/gradio-tools) is a powerful library that allows using Hugging
-Face Spaces as tools. It supports many existing Spaces as well as custom Spaces.
+You can directly import a Space from the Hub as a tool using the [`Tool.from_space`] method!
-Transformers supports `gradio_tools` with the [`Tool.from_gradio`] method. For example, let's use the [`StableDiffusionPromptGeneratorTool`](https://github.com/freddyaboulton/gradio-tools/blob/main/gradio_tools/tools/prompt_generator.py) from `gradio-tools` toolkit for improving prompts to generate better images.
+You only need to provide the id of the Space on the Hub, its name, and a description that will help your agent understand what the tool does. Under the hood, this will use the [`gradio-client`](https://pypi.org/project/gradio-client/) library to call the Space.
-Import and instantiate the tool, then pass it to the `Tool.from_gradio` method:
+For instance, let's import the [FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) Space from the Hub and use it to generate an image.
-```python
-from gradio_tools import StableDiffusionPromptGeneratorTool
-from transformers import Tool, load_tool, CodeAgent
+```
+from transformers import Tool
-gradio_prompt_generator_tool = StableDiffusionPromptGeneratorTool()
-prompt_generator_tool = Tool.from_gradio(gradio_prompt_generator_tool)
+image_generation_tool = Tool.from_space(
+ "black-forest-labs/FLUX.1-dev",
+ name="image_generator",
+ description="Generate an image from a prompt")
+
+image_generation_tool("A sunny beach")
```
+And voilà, here's your image! 🏖️
-Now you can use it just like any other tool. For example, let's improve the prompt `a rabbit wearing a space suit`.
+
+
+Then you can use this tool just like any other tool. For example, let's improve the prompt `a rabbit wearing a space suit` and generate an image of it.
```python
-image_generation_tool = load_tool('huggingface-tools/text-to-image')
-agent = CodeAgent(tools=[prompt_generator_tool, image_generation_tool], llm_engine=llm_engine)
+from transformers import ReactCodeAgent
+
+agent = ReactCodeAgent(tools=[image_generation_tool])
agent.run(
"Improve this prompt, then generate an image of it.", prompt='A rabbit wearing a space suit'
)
```
-The model adequately leverages the tool:
```text
-======== New task ========
-Improve this prompt, then generate an image of it.
-You have been provided with these initial arguments: {'prompt': 'A rabbit wearing a space suit'}.
-==== Agent is executing the code below:
-improved_prompt = StableDiffusionPromptGenerator(query=prompt)
-while improved_prompt == "QUEUE_FULL":
- improved_prompt = StableDiffusionPromptGenerator(query=prompt)
-print(f"The improved prompt is {improved_prompt}.")
-image = image_generator(prompt=improved_prompt)
-====
+=== Agent thoughts:
+improved_prompt could be "A bright blue space suit wearing rabbit, on the surface of the moon, under a bright orange sunset, with the Earth visible in the background"
+
+Now that I have improved the prompt, I can use the image generator tool to generate an image based on this prompt.
+>>> Agent is executing the code below:
+image = image_generator(prompt="A bright blue space suit wearing rabbit, on the surface of the moon, under a bright orange sunset, with the Earth visible in the background")
+final_answer(image)
```
-Before finally generating the image:
+
-
+How cool is this? 🤩
+### Use gradio-tools
+
+[gradio-tools](https://github.com/freddyaboulton/gradio-tools) is a powerful library that allows using Hugging
+Face Spaces as tools. It supports many existing Spaces as well as custom Spaces.
+
+Transformers supports `gradio_tools` with the [`Tool.from_gradio`] method. For example, let's use the [`StableDiffusionPromptGeneratorTool`](https://github.com/freddyaboulton/gradio-tools/blob/main/gradio_tools/tools/prompt_generator.py) from the `gradio-tools` toolkit for improving prompts to generate better images.
+
+Import and instantiate the tool, then pass it to the `Tool.from_gradio` method:
+
+```python
+from gradio_tools import StableDiffusionPromptGeneratorTool
+from transformers import Tool, load_tool, CodeAgent
+
+gradio_prompt_generator_tool = StableDiffusionPromptGeneratorTool()
+prompt_generator_tool = Tool.from_gradio(gradio_prompt_generator_tool)
+```
> [!WARNING]
> gradio-tools require *textual* inputs and outputs even when working with different modalities like image and audio objects. Image and audio inputs and outputs are currently incompatible.
@@ -179,7 +197,7 @@ We love Langchain and think it has a very compelling suite of tools.
To import a tool from LangChain, use the `from_langchain()` method.
Here is how you can use it to recreate the intro's search result using a LangChain web search tool.
-
+This tool will need `pip install google-search-results` to work properly.
```python
from langchain.agents import load_tools
from transformers import Tool, ReactCodeAgent
@@ -188,7 +206,7 @@ search_tool = Tool.from_langchain(load_tools(["serpapi"])[0])
agent = ReactCodeAgent(tools=[search_tool])
-agent.run("How many more blocks (also denoted as layers) in BERT base encoder than the encoder from the architecture proposed in Attention is All You Need?")
+agent.run("How many more blocks (also denoted as layers) are in BERT base encoder compared to the encoder from the architecture proposed in Attention is All You Need?")
```
## Display your agent run in a cool Gradio interface
diff --git a/src/transformers/agents/tools.py b/src/transformers/agents/tools.py
index a425ffc8f10..994e1bdd817 100644
--- a/src/transformers/agents/tools.py
+++ b/src/transformers/agents/tools.py
@@ -87,20 +87,22 @@ def get_repo_type(repo_id, repo_type=None, **hub_kwargs):
"""
-def validate_after_init(cls):
+def validate_after_init(cls, do_validate_forward: bool = True):
original_init = cls.__init__
@wraps(original_init)
def new_init(self, *args, **kwargs):
original_init(self, *args, **kwargs)
if not isinstance(self, PipelineTool):
- self.validate_arguments()
+ self.validate_arguments(do_validate_forward=do_validate_forward)
cls.__init__ = new_init
return cls
-@validate_after_init
+CONVERSION_DICT = {"str": "string", "int": "integer", "float": "number"}
+
+
class Tool:
"""
A base class for the functions used by the agent. Subclass this and implement the `__call__` method as well as the
@@ -131,7 +133,11 @@ class Tool:
def __init__(self, *args, **kwargs):
self.is_initialized = False
- def validate_arguments(self):
+ def __init_subclass__(cls, **kwargs):
+ super().__init_subclass__(**kwargs)
+ validate_after_init(cls, do_validate_forward=False)
+
+ def validate_arguments(self, do_validate_forward: bool = True):
required_attributes = {
"description": str,
"name": str,
@@ -145,21 +151,23 @@ def validate_arguments(self):
if not isinstance(attr_value, expected_type):
raise TypeError(f"You must set an attribute {attr} of type {expected_type.__name__}.")
for input_name, input_content in self.inputs.items():
- assert "type" in input_content, f"Input '{input_name}' should specify a type."
+ assert isinstance(input_content, dict), f"Input '{input_name}' should be a dictionary."
+ assert (
+ "type" in input_content and "description" in input_content
+ ), f"Input '{input_name}' should have keys 'type' and 'description', has only {list(input_content.keys())}."
if input_content["type"] not in authorized_types:
raise Exception(
f"Input '{input_name}': type '{input_content['type']}' is not an authorized value, should be one of {authorized_types}."
)
- assert "description" in input_content, f"Input '{input_name}' should have a description."
assert getattr(self, "output_type", None) in authorized_types
-
- if not isinstance(self, PipelineTool):
- signature = inspect.signature(self.forward)
- if not set(signature.parameters.keys()) == set(self.inputs.keys()):
- raise Exception(
- "Tool's 'forward' method should take 'self' as its first argument, then its next arguments should match the keys of tool attribute 'inputs'."
- )
+ if do_validate_forward:
+ if not isinstance(self, PipelineTool):
+ signature = inspect.signature(self.forward)
+ if not set(signature.parameters.keys()) == set(self.inputs.keys()):
+ raise Exception(
+ "Tool's 'forward' method should take 'self' as its first argument, then its next arguments should match the keys of tool attribute 'inputs'."
+ )
def forward(self, *args, **kwargs):
return NotImplemented("Write this method in your subclass of `Tool`.")
@@ -405,6 +413,58 @@ def push_to_hub(
repo_type="space",
)
+ @staticmethod
+ def from_space(space_id, name, description):
+ """
+ Creates a [`Tool`] from a Space given its id on the Hub.
+
+ Args:
+ space_id (`str`):
+ The id of the Space on the Hub.
+ name (`str`):
+ The name of the tool.
+ description (`str`):
+ The description of the tool.
+
+ Returns:
+ [`Tool`]:
+ The created tool.
+
+ Example:
+ ```
+        tool = Tool.from_space("black-forest-labs/FLUX.1-schnell", "image_generator", "Generate an image from a prompt")
+ ```
+ """
+ from gradio_client import Client
+
+ class SpaceToolWrapper(Tool):
+ def __init__(self, space_id, name, description):
+ self.client = Client(space_id)
+ self.name = name
+ self.description = description
+ space_description = self.client.view_api(return_format="dict")["named_endpoints"]
+ route = list(space_description.keys())[0]
+ space_description_route = space_description[route]
+ self.inputs = {}
+ for parameter in space_description_route["parameters"]:
+ if not parameter["parameter_has_default"]:
+ self.inputs[parameter["parameter_name"]] = {
+ "type": parameter["type"]["type"],
+ "description": parameter["python_type"]["description"],
+ }
+ output_component = space_description_route["returns"][0]["component"]
+ if output_component == "Image":
+ self.output_type = "image"
+ elif output_component == "Audio":
+ self.output_type = "audio"
+ else:
+ self.output_type = "any"
+
+ def forward(self, *args, **kwargs):
+ return self.client.predict(*args, **kwargs)[0] # Usually the first output is the result
+
+ return SpaceToolWrapper(space_id, name, description)
+
@staticmethod
def from_gradio(gradio_tool):
"""
@@ -414,16 +474,15 @@ def from_gradio(gradio_tool):
class GradioToolWrapper(Tool):
def __init__(self, _gradio_tool):
- super().__init__()
self.name = _gradio_tool.name
self.description = _gradio_tool.description
self.output_type = "string"
self._gradio_tool = _gradio_tool
- func_args = list(inspect.signature(_gradio_tool.run).parameters.keys())
- self.inputs = {key: "" for key in func_args}
-
- def forward(self, *args, **kwargs):
- return self._gradio_tool.run(*args, **kwargs)
+ func_args = list(inspect.signature(_gradio_tool.run).parameters.items())
+ self.inputs = {
+ key: {"type": CONVERSION_DICT[value.annotation], "description": ""} for key, value in func_args
+ }
+ self.forward = self._gradio_tool.run
return GradioToolWrapper(gradio_tool)
@@ -435,10 +494,13 @@ def from_langchain(langchain_tool):
class LangChainToolWrapper(Tool):
def __init__(self, _langchain_tool):
- super().__init__()
self.name = _langchain_tool.name.lower()
self.description = _langchain_tool.description
- self.inputs = parse_langchain_args(_langchain_tool.args)
+ self.inputs = _langchain_tool.args.copy()
+ for input_content in self.inputs.values():
+ if "title" in input_content:
+ input_content.pop("title")
+ input_content["description"] = ""
self.output_type = "string"
self.langchain_tool = _langchain_tool
@@ -805,15 +867,6 @@ def __call__(
return response.json()
-def parse_langchain_args(args: Dict[str, str]) -> Dict[str, str]:
- """Parse the args attribute of a LangChain tool to create a matching inputs dictionary."""
- inputs = args.copy()
- for arg_details in inputs.values():
- if "title" in arg_details:
- arg_details.pop("title")
- return inputs
-
-
class ToolCollection:
"""
Tool collections enable loading all Spaces from a collection in order to be added to the agent's toolbox.