Skip to content

Commit

Permalink
Allow handling files as args for a tool created with Tool.from_space
Browse files Browse the repository at this point in the history
  • Loading branch information
aymeric-roucher committed Nov 11, 2024
1 parent 25f510a commit aba1bc4
Showing 1 changed file with 53 additions and 15 deletions.
68 changes: 53 additions & 15 deletions src/transformers/agents/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import os
import tempfile
from functools import lru_cache, wraps
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Union

from huggingface_hub import create_repo, get_collection, hf_hub_download, metadata_update, upload_folder
Expand Down Expand Up @@ -414,7 +415,7 @@ def push_to_hub(
)

@staticmethod
def from_space(space_id, name, description):
def from_space(space_id: str, name: str, description: str, api_name: Optional[str] = None):
"""
Creates a [`Tool`] from a Space given its id on the Hub.
Expand All @@ -425,34 +426,63 @@ def from_space(space_id, name, description):
The name of the tool.
description (`str`):
The description of the tool.
api_name (`str`, *optional*):
The specific api_name to use, if the space has several tabs. If not precised, will default to the first available api.
Returns:
[`Tool`]:
The created tool.
The Space, as a tool.
Example:
Examples:
```
tool = Tool.from_space("black-forest-labs/FLUX.1-schnell", "image-generator", "Generate an image from a prompt")
image_generator = Tool.from_space(
space_id="black-forest-labs/FLUX.1-schnell",
name="image-generator",
description="Generate an image from a prompt"
)
image = image_generator("Generate an image of a cool surfer in Tahiti")
```
```
face_swapper = Tool.from_space(
"tuan2308/face-swap",
"face_swapper",
"Tool that puts the face shown on the first image on the second image. You can give it paths to images.",
)
image = face_swapper('./aymeric.jpeg', './ruth.jpg')
```
"""
from gradio_client import Client
from gradio_client import Client, handle_file
from gradio_client.utils import is_http_url_like

class SpaceToolWrapper(Tool):
def __init__(self, space_id, name, description):
def __init__(self, space_id: str, name: str, description: str, api_name: Optional[str] = None):
self.client = Client(space_id)
self.name = name
self.description = description
space_description = self.client.view_api(return_format="dict")["named_endpoints"]
route = list(space_description.keys())[0]
space_description_route = space_description[route]
space_description = self.client.view_api(return_format="dict", print_info=False)["named_endpoints"]

# If api_name is not defined, take the first of the available APIs for this space
if api_name is None:
api_name = list(space_description.keys())[0]
logger.warning(f"Since `api_name` was not defined, it was automatically set to the first avilable API: `{api_name}`.")
self.api_name = api_name

try:
space_description_api = space_description[api_name]
except KeyError:
raise KeyError(f"Could not find specified {api_name=} among available api names.")

self.inputs = {}
for parameter in space_description_route["parameters"]:
for parameter in space_description_api["parameters"]:
if not parameter["parameter_has_default"]:
parameter_type = parameter["type"]["type"]
if parameter_type == "object":
parameter_type = "any"
self.inputs[parameter["parameter_name"]] = {
"type": parameter["type"]["type"],
"type": parameter_type,
"description": parameter["python_type"]["description"],
}
output_component = space_description_route["returns"][0]["component"]
output_component = space_description_api["returns"][0]["component"]
if output_component == "Image":
self.output_type = "image"
elif output_component == "Audio":
Expand All @@ -461,9 +491,17 @@ def __init__(self, space_id, name, description):
self.output_type = "any"

def forward(self, *args, **kwargs):
return self.client.predict(*args, **kwargs)[0] # Usually the first output is the result

return SpaceToolWrapper(space_id, name, description)
# Test if any arg is a file and processes it accordingly:
args = list(args)
for i, arg in enumerate(args):
if (
isinstance(arg, (str, Path)) and Path(arg).exists() and Path(arg).is_file()
) or is_http_url_like(arg):
args[i] = handle_file(arg)
output = self.client.predict(*args, api_name=self.api_name, **kwargs)
return output[0] # Usually the first output is the result

return SpaceToolWrapper(space_id, name, description, api_name=api_name)

@staticmethod
def from_gradio(gradio_tool):
Expand Down

0 comments on commit aba1bc4

Please sign in to comment.