Agents: turn any Space into a Tool with Tool.from_space() (huggingface#34561)

* Agents: you can now load a Space as a tool
aymeric-roucher authored and BernardZach committed Dec 5, 2024
1 parent c4674ad commit fd60412
Showing 3 changed files with 130 additions and 59 deletions.
2 changes: 1 addition & 1 deletion docs/source/ar/agents.md
@@ -464,7 +464,7 @@ image = image_generator(prompt=improved_prompt)

Before finally generating the image:

<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png" />
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit_spacesuit_flux.webp" />

> [!WARNING]
> gradio-tools requires *textual* inputs and outputs even when working with different modalities such as image and audio objects. Image and audio inputs and outputs are currently incompatible.
74 changes: 46 additions & 28 deletions docs/source/en/agents_advanced.md
@@ -123,52 +123,70 @@ from transformers import load_tool, CodeAgent
model_download_tool = load_tool("m-ric/hf-model-downloads")
```

### Use gradio-tools
### Import a Space as a tool 🚀

[gradio-tools](https://github.com/freddyaboulton/gradio-tools) is a powerful library that allows using Hugging
Face Spaces as tools. It supports many existing Spaces as well as custom Spaces.
You can directly import a Space from the Hub as a tool using the [`Tool.from_space`] method!

Transformers supports `gradio_tools` with the [`Tool.from_gradio`] method. For example, let's use the [`StableDiffusionPromptGeneratorTool`](https://github.com/freddyaboulton/gradio-tools/blob/main/gradio_tools/tools/prompt_generator.py) from `gradio-tools` toolkit for improving prompts to generate better images.
You only need to provide the id of the Space on the Hub, its name, and a description that will help your agent understand what the tool does. Under the hood, this uses the [`gradio-client`](https://pypi.org/project/gradio-client/) library to call the Space.

Import and instantiate the tool, then pass it to the `Tool.from_gradio` method:
For instance, let's import the [FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) Space from the Hub and use it to generate an image.

```python
from gradio_tools import StableDiffusionPromptGeneratorTool
from transformers import Tool, load_tool, CodeAgent
```
from transformers import Tool
gradio_prompt_generator_tool = StableDiffusionPromptGeneratorTool()
prompt_generator_tool = Tool.from_gradio(gradio_prompt_generator_tool)
image_generation_tool = Tool.from_space(
    "black-forest-labs/FLUX.1-dev",
    name="image_generator",
    description="Generate an image from a prompt")
image_generation_tool("A sunny beach")
```
And voilà, here's your image! 🏖️

Now you can use it just like any other tool. For example, let's improve the prompt `a rabbit wearing a space suit`.
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/sunny_beach.webp">

Then you can use this tool just like any other tool. For example, let's improve the prompt `a rabbit wearing a space suit` and generate an image of it.

```python
image_generation_tool = load_tool('huggingface-tools/text-to-image')
agent = CodeAgent(tools=[prompt_generator_tool, image_generation_tool], llm_engine=llm_engine)
from transformers import ReactCodeAgent

agent = ReactCodeAgent(tools=[image_generation_tool])

agent.run(
    "Improve this prompt, then generate an image of it.", prompt='A rabbit wearing a space suit'
)
```

The model adequately leverages the tool:
```text
======== New task ========
Improve this prompt, then generate an image of it.
You have been provided with these initial arguments: {'prompt': 'A rabbit wearing a space suit'}.
==== Agent is executing the code below:
improved_prompt = StableDiffusionPromptGenerator(query=prompt)
while improved_prompt == "QUEUE_FULL":
    improved_prompt = StableDiffusionPromptGenerator(query=prompt)
print(f"The improved prompt is {improved_prompt}.")
image = image_generator(prompt=improved_prompt)
====
=== Agent thoughts:
improved_prompt could be "A bright blue space suit wearing rabbit, on the surface of the moon, under a bright orange sunset, with the Earth visible in the background"
Now that I have improved the prompt, I can use the image generator tool to generate an image based on this prompt.
>>> Agent is executing the code below:
image = image_generator(prompt="A bright blue space suit wearing rabbit, on the surface of the moon, under a bright orange sunset, with the Earth visible in the background")
final_answer(image)
```

Before finally generating the image:
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit_spacesuit_flux.webp">

<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png">
How cool is this? 🤩

### Use gradio-tools

[gradio-tools](https://github.com/freddyaboulton/gradio-tools) is a powerful library that allows using Hugging
Face Spaces as tools. It supports many existing Spaces as well as custom Spaces.

Transformers supports `gradio_tools` with the [`Tool.from_gradio`] method. For example, let's use the [`StableDiffusionPromptGeneratorTool`](https://github.com/freddyaboulton/gradio-tools/blob/main/gradio_tools/tools/prompt_generator.py) from `gradio-tools` toolkit for improving prompts to generate better images.

Import and instantiate the tool, then pass it to the `Tool.from_gradio` method:

```python
from gradio_tools import StableDiffusionPromptGeneratorTool
from transformers import Tool, load_tool, CodeAgent

gradio_prompt_generator_tool = StableDiffusionPromptGeneratorTool()
prompt_generator_tool = Tool.from_gradio(gradio_prompt_generator_tool)
```

> [!WARNING]
> gradio-tools require *textual* inputs and outputs even when working with different modalities like image and audio objects. Image and audio inputs and outputs are currently incompatible.
@@ -179,7 +197,7 @@ We love Langchain and think it has a very compelling suite of tools.
To import a tool from LangChain, use the `from_langchain()` method.

Here is how you can use it to recreate the intro's search result using a LangChain web search tool.

This tool will need `pip install google-search-results` to work properly.
```python
from langchain.agents import load_tools
from transformers import Tool, ReactCodeAgent
@@ -188,7 +206,7 @@ search_tool = Tool.from_langchain(load_tools(["serpapi"])[0])

agent = ReactCodeAgent(tools=[search_tool])

agent.run("How many more blocks (also denoted as layers) in BERT base encoder than the encoder from the architecture proposed in Attention is All You Need?")
agent.run("How many more blocks (also denoted as layers) are in BERT base encoder compared to the encoder from the architecture proposed in Attention is All You Need?")
```

## Display your agent run in a cool Gradio interface
113 changes: 83 additions & 30 deletions src/transformers/agents/tools.py
@@ -87,20 +87,22 @@ def get_repo_type(repo_id, repo_type=None, **hub_kwargs):
"""


def validate_after_init(cls):
def validate_after_init(cls, do_validate_forward: bool = True):
    original_init = cls.__init__

    @wraps(original_init)
    def new_init(self, *args, **kwargs):
        original_init(self, *args, **kwargs)
        if not isinstance(self, PipelineTool):
            self.validate_arguments()
            self.validate_arguments(do_validate_forward=do_validate_forward)

    cls.__init__ = new_init
    return cls


@validate_after_init
CONVERSION_DICT = {"str": "string", "int": "integer", "float": "number"}


class Tool:
    """
    A base class for the functions used by the agent. Subclass this and implement the `__call__` method as well as the
@@ -131,7 +133,11 @@ class Tool:
    def __init__(self, *args, **kwargs):
        self.is_initialized = False

    def validate_arguments(self):
    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        validate_after_init(cls, do_validate_forward=False)

    def validate_arguments(self, do_validate_forward: bool = True):
        required_attributes = {
            "description": str,
            "name": str,
@@ -145,21 +151,23 @@ def validate_arguments(self):
            if not isinstance(attr_value, expected_type):
                raise TypeError(f"You must set an attribute {attr} of type {expected_type.__name__}.")
        for input_name, input_content in self.inputs.items():
            assert "type" in input_content, f"Input '{input_name}' should specify a type."
            assert isinstance(input_content, dict), f"Input '{input_name}' should be a dictionary."
            assert (
                "type" in input_content and "description" in input_content
            ), f"Input '{input_name}' should have keys 'type' and 'description', has only {list(input_content.keys())}."
            if input_content["type"] not in authorized_types:
                raise Exception(
                    f"Input '{input_name}': type '{input_content['type']}' is not an authorized value, should be one of {authorized_types}."
                )
            assert "description" in input_content, f"Input '{input_name}' should have a description."

        assert getattr(self, "output_type", None) in authorized_types

        if not isinstance(self, PipelineTool):
            signature = inspect.signature(self.forward)
            if not set(signature.parameters.keys()) == set(self.inputs.keys()):
                raise Exception(
                    "Tool's 'forward' method should take 'self' as its first argument, then its next arguments should match the keys of tool attribute 'inputs'."
                )
        if do_validate_forward:
            if not isinstance(self, PipelineTool):
                signature = inspect.signature(self.forward)
                if not set(signature.parameters.keys()) == set(self.inputs.keys()):
                    raise Exception(
                        "Tool's 'forward' method should take 'self' as its first argument, then its next arguments should match the keys of tool attribute 'inputs'."
                    )

    def forward(self, *args, **kwargs):
        # `NotImplemented` is a constant, not a callable; raise the exception instead.
        raise NotImplementedError("Write this method in your subclass of `Tool`.")
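
The post-init validation pattern in this hunk (wrapping `__init__` from `__init_subclass__` so that every subclass is checked once its constructor returns) can be illustrated with a minimal, self-contained sketch. `BaseTool`, `GoodTool`, and `BadTool` are hypothetical stand-ins, and the check is reduced to `name` and `description`; this is not the library's implementation.

```python
from functools import wraps


def validate_after_init(cls):
    """Wrap cls.__init__ so that self.validate_arguments() runs right after it."""
    original_init = cls.__init__

    @wraps(original_init)
    def new_init(self, *args, **kwargs):
        original_init(self, *args, **kwargs)
        self.validate_arguments()

    cls.__init__ = new_init
    return cls


class BaseTool:
    """Hypothetical stand-in for `Tool`; only checks `name` and `description`."""

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Called once per subclass definition: wrap that subclass's __init__.
        validate_after_init(cls)

    def validate_arguments(self):
        for attr in ("name", "description"):
            if not isinstance(getattr(self, attr, None), str):
                raise TypeError(f"You must set an attribute {attr} of type str.")


class GoodTool(BaseTool):
    def __init__(self):
        self.name = "echo"
        self.description = "Returns its input unchanged."


class BadTool(BaseTool):
    def __init__(self):
        self.name = "broken"  # description is deliberately missing
```

With this setup, `GoodTool()` constructs normally, while `BadTool()` raises a `TypeError` as soon as its `__init__` finishes, without either subclass having to call the validation itself.
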
@@ -405,6 +413,58 @@ def push_to_hub(
            repo_type="space",
        )

    @staticmethod
    def from_space(space_id, name, description):
        """
        Creates a [`Tool`] from a Space given its id on the Hub.
        Args:
            space_id (`str`):
                The id of the Space on the Hub.
            name (`str`):
                The name of the tool.
            description (`str`):
                The description of the tool.
        Returns:
            [`Tool`]:
                The created tool.
        Example:
            ```
            tool = Tool.from_space("black-forest-labs/FLUX.1-schnell", "image-generator", "Generate an image from a prompt")
            ```
        """
        from gradio_client import Client

        class SpaceToolWrapper(Tool):
            def __init__(self, space_id, name, description):
                self.client = Client(space_id)
                self.name = name
                self.description = description
                space_description = self.client.view_api(return_format="dict")["named_endpoints"]
                route = list(space_description.keys())[0]
                space_description_route = space_description[route]
                self.inputs = {}
                for parameter in space_description_route["parameters"]:
                    if not parameter["parameter_has_default"]:
                        self.inputs[parameter["parameter_name"]] = {
                            "type": parameter["type"]["type"],
                            "description": parameter["python_type"]["description"],
                        }
                output_component = space_description_route["returns"][0]["component"]
                if output_component == "Image":
                    self.output_type = "image"
                elif output_component == "Audio":
                    self.output_type = "audio"
                else:
                    self.output_type = "any"

            def forward(self, *args, **kwargs):
                return self.client.predict(*args, **kwargs)[0]  # Usually the first output is the result

        return SpaceToolWrapper(space_id, name, description)
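
To make the endpoint parsing in `from_space` easier to follow without calling a real Space, here is a small offline sketch of the same mapping. The helper names `inputs_from_endpoint` and `output_type_from_endpoint` are hypothetical, and `sample_endpoint` is a hand-written dictionary that mirrors only the keys read by the code in this diff. It is an assumption about the shape, not captured output from `view_api`.

```python
def inputs_from_endpoint(endpoint: dict) -> dict:
    """Keep only parameters without a default and map them to the tool `inputs` format."""
    inputs = {}
    for parameter in endpoint["parameters"]:
        if not parameter["parameter_has_default"]:
            inputs[parameter["parameter_name"]] = {
                "type": parameter["type"]["type"],
                "description": parameter["python_type"]["description"],
            }
    return inputs


def output_type_from_endpoint(endpoint: dict) -> str:
    """Map the first returned component to 'image', 'audio', or the fallback 'any'."""
    component = endpoint["returns"][0]["component"]
    return {"Image": "image", "Audio": "audio"}.get(component, "any")


# Hand-written sample for illustration only (assumed shape, not real Space output).
sample_endpoint = {
    "parameters": [
        {
            "parameter_name": "prompt",
            "parameter_has_default": False,
            "type": {"type": "string"},
            "python_type": {"type": "str", "description": "Text prompt"},
        },
        {
            "parameter_name": "seed",
            "parameter_has_default": True,
            "type": {"type": "number"},
            "python_type": {"type": "float", "description": "Random seed"},
        },
    ],
    "returns": [{"component": "Image"}],
}

tool_inputs = inputs_from_endpoint(sample_endpoint)
tool_output_type = output_type_from_endpoint(sample_endpoint)
```

In this sample, `seed` has a default and is therefore dropped, so the agent only sees `prompt` as a tool input.
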

    @staticmethod
    def from_gradio(gradio_tool):
        """
@@ -414,16 +474,15 @@ def from_gradio(gradio_tool):

        class GradioToolWrapper(Tool):
            def __init__(self, _gradio_tool):
                super().__init__()
                self.name = _gradio_tool.name
                self.description = _gradio_tool.description
                self.output_type = "string"
                self._gradio_tool = _gradio_tool
                func_args = list(inspect.signature(_gradio_tool.run).parameters.keys())
                self.inputs = {key: "" for key in func_args}

            def forward(self, *args, **kwargs):
                return self._gradio_tool.run(*args, **kwargs)
                func_args = list(inspect.signature(_gradio_tool.run).parameters.items())
                self.inputs = {
                    key: {"type": CONVERSION_DICT[value.annotation], "description": ""} for key, value in func_args
                }
                self.forward = self._gradio_tool.run

        return GradioToolWrapper(gradio_tool)
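
The signature-to-`inputs` conversion above can be sketched in isolation. One detail worth keeping in mind: `inspect` normally returns the annotation object itself (for example the class `str`), while the keys of `CONVERSION_DICT` are strings, so this sketch normalizes the annotation to a name before the lookup. The helper `inputs_from_signature` and the function `run` are hypothetical; this is one possible approach under those assumptions, not the code merged in this commit.

```python
import inspect

CONVERSION_DICT = {"str": "string", "int": "integer", "float": "number"}


def inputs_from_signature(func) -> dict:
    """Build a tool `inputs` dict from the annotated parameters of `func`."""
    inputs = {}
    for key, param in inspect.signature(func).parameters.items():
        annotation = param.annotation
        # Annotations are usually classes (str) but can be plain strings ("str"),
        # e.g. under `from __future__ import annotations`; normalize both to a name.
        if isinstance(annotation, str):
            type_name = annotation
        else:
            type_name = getattr(annotation, "__name__", None)
        if type_name not in CONVERSION_DICT:
            raise TypeError(f"Unsupported or missing annotation for parameter '{key}'.")
        inputs[key] = {"type": CONVERSION_DICT[type_name], "description": ""}
    return inputs


def run(query: str, steps: int) -> str:
    """Hypothetical stand-in for a gradio tool's `run` method."""
    return query * steps


def run_untyped(query):
    """Hypothetical function without annotations, used to show the error path."""
    return query


run_inputs = inputs_from_signature(run)
```

Raising early on an unsupported or missing annotation is a design choice of this sketch: it surfaces the problem when the tool is wrapped rather than when the agent first calls it.
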

@@ -435,10 +494,13 @@ def from_langchain(langchain_tool):

        class LangChainToolWrapper(Tool):
            def __init__(self, _langchain_tool):
                super().__init__()
                self.name = _langchain_tool.name.lower()
                self.description = _langchain_tool.description
                self.inputs = parse_langchain_args(_langchain_tool.args)
                self.inputs = _langchain_tool.args.copy()
                for input_content in self.inputs.values():
                    if "title" in input_content:
                        input_content.pop("title")
                    input_content["description"] = ""
                self.output_type = "string"
                self.langchain_tool = _langchain_tool
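
The argument normalization inlined here (drop each `title` key and add an empty `description`, so that every input has the `type` and `description` keys that `validate_arguments` expects) can be tried on its own. The helper name and `sample_args` are hypothetical, and the sketch assumes the empty description is set for every argument. It also uses `copy.deepcopy` rather than `dict.copy()`, a deliberate deviation so that the nested dictionaries of the source object stay untouched.

```python
import copy


def inputs_from_langchain_args(args: dict) -> dict:
    """Turn a LangChain-style `args` mapping into a tool `inputs` mapping."""
    # deepcopy (a choice of this sketch) keeps the caller's nested dicts unmodified.
    inputs = copy.deepcopy(args)
    for input_content in inputs.values():
        input_content.pop("title", None)  # remove the title if present
        input_content["description"] = ""  # required key, left empty here
    return inputs


# Hand-written sample for illustration only.
sample_args = {"query": {"title": "Query", "type": "string"}}
converted_args = inputs_from_langchain_args(sample_args)
```
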

@@ -805,15 +867,6 @@ def __call__(
        return response.json()


def parse_langchain_args(args: Dict[str, str]) -> Dict[str, str]:
    """Parse the args attribute of a LangChain tool to create a matching inputs dictionary."""
    inputs = args.copy()
    for arg_details in inputs.values():
        if "title" in arg_details:
            arg_details.pop("title")
    return inputs


class ToolCollection:
    """
    Tool collections enable loading all Spaces from a collection in order to be added to the agent's toolbox.