Skip to content

Commit

Permalink
Fix formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
aymeric-roucher committed Nov 4, 2024
1 parent c5f36f4 commit e9ea3ba
Showing 1 changed file with 12 additions and 10 deletions.
22 changes: 12 additions & 10 deletions src/transformers/agents/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,10 @@ def new_init(self, *args, **kwargs):
cls.__init__ = new_init
return cls


CONVERSION_DICT = {"str": "string", "int": "integer", "float": "number"}


class Tool:
"""
A base class for the functions used by the agent. Subclass this and implement the `__call__` method as well as the
Expand Down Expand Up @@ -135,7 +137,6 @@ def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
validate_after_init(cls, do_validate_forward=False)


def validate_arguments(self, do_validate_forward: bool = True):
required_attributes = {
"description": str,
Expand All @@ -151,7 +152,9 @@ def validate_arguments(self, do_validate_forward: bool = True):
raise TypeError(f"You must set an attribute {attr} of type {expected_type.__name__}.")
for input_name, input_content in self.inputs.items():
assert isinstance(input_content, dict), f"Input '{input_name}' should be a dictionary."
assert "type" in input_content and "description" in input_content, f"Input '{input_name}' should have keys 'type' and 'description', has only {list(input_content.keys())}."
assert (
"type" in input_content and "description" in input_content
), f"Input '{input_name}' should have keys 'type' and 'description', has only {list(input_content.keys())}."
if input_content["type"] not in authorized_types:
raise Exception(
f"Input '{input_name}': type '{input_content['type']}' is not an authorized value, should be one of {authorized_types}."
Expand Down Expand Up @@ -409,7 +412,7 @@ def push_to_hub(
create_pr=create_pr,
repo_type="space",
)

@staticmethod
def from_space(space_id, name, description):
"""
Expand Down Expand Up @@ -439,9 +442,7 @@ def __init__(self, space_id, name, description):
self.client = Client(space_id)
self.name = name
self.description = description
space_description = self.client.view_api(return_format="dict")[
"named_endpoints"
]
space_description = self.client.view_api(return_format="dict")["named_endpoints"]
route = list(space_description.keys())[0]
space_description_route = space_description[route]
self.inputs = {}
Expand All @@ -460,10 +461,9 @@ def __init__(self, space_id, name, description):
self.output_type = "any"

def forward(self, *args, **kwargs):
return self.client.predict(*args, **kwargs)[0] # Usually the first output is the result

return SpaceToolWrapper(space_id, name, description)
return self.client.predict(*args, **kwargs)[0] # Usually the first output is the result

return SpaceToolWrapper(space_id, name, description)

@staticmethod
def from_gradio(gradio_tool):
Expand All @@ -479,7 +479,9 @@ def __init__(self, _gradio_tool):
self.output_type = "string"
self._gradio_tool = _gradio_tool
func_args = list(inspect.signature(_gradio_tool.run).parameters.items())
self.inputs = {key: {"type": CONVERSION_DICT[value.annotation], "description": ""} for key, value in func_args}
self.inputs = {
key: {"type": CONVERSION_DICT[value.annotation], "description": ""} for key, value in func_args
}
self.forward = self._gradio_tool.run

return GradioToolWrapper(gradio_tool)
Expand Down

0 comments on commit e9ea3ba

Please sign in to comment.