Skip to content

Commit

Permalink
Fix launch_gradio_demo
Browse files Browse the repository at this point in the history
  • Loading branch information
aymeric-roucher committed Nov 11, 2024
1 parent 427db5f commit 8de9c43
Showing 1 changed file with 13 additions and 12 deletions.
25 changes: 13 additions & 12 deletions src/transformers/agents/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,21 +720,22 @@ def launch_gradio_demo(tool_class: Tool):
def fn(*args, **kwargs):
return tool(*args, **kwargs)

TYPE_TO_COMPONENT_CLASS_MAPPING = {
"image": gr.Image,
"audio": gr.Audio,
"string": gr.Textbox,
"integer": gr.Textbox,
"number": gr.Textbox,
}

gradio_inputs = []
for input_name, input_details in tool_class.inputs.items():
input_type = input_details["type"]
if input_type == "image":
gradio_inputs.append(gr.Image(label=input_name))
elif input_type == "audio":
gradio_inputs.append(gr.Audio(label=input_name))
elif input_type in ["string", "integer", "number"]:
gradio_inputs.append(gr.Textbox(label=input_name))
else:
error_message = f"Input type '{input_type}' not supported."
raise ValueError(error_message)
input_gradio_component_class = TYPE_TO_COMPONENT_CLASS_MAPPING[input_details["type"]]
new_component = input_gradio_component_class(label=input_name)
gradio_inputs.append(new_component)

gradio_output = tool_class.output_type
assert gradio_output in ["string", "image", "audio"], f"Output type '{gradio_output}' not supported."
output_gradio_componentclass = TYPE_TO_COMPONENT_CLASS_MAPPING[tool_class.output_type]
gradio_output = output_gradio_componentclass(label=input_name)

gr.Interface(
fn=fn,
Expand Down

0 comments on commit 8de9c43

Please sign in to comment.