diff --git a/src/transformers/agents/tools.py b/src/transformers/agents/tools.py index 84bcf0fde61f18..caad34dae7eb74 100644 --- a/src/transformers/agents/tools.py +++ b/src/transformers/agents/tools.py @@ -720,21 +720,22 @@ def launch_gradio_demo(tool_class: Tool): def fn(*args, **kwargs): return tool(*args, **kwargs) + TYPE_TO_COMPONENT_CLASS_MAPPING = { + "image": gr.Image, + "audio": gr.Audio, + "string": gr.Textbox, + "integer": gr.Textbox, + "number": gr.Textbox, + } + gradio_inputs = [] for input_name, input_details in tool_class.inputs.items(): - input_type = input_details["type"] - if input_type == "image": - gradio_inputs.append(gr.Image(label=input_name)) - elif input_type == "audio": - gradio_inputs.append(gr.Audio(label=input_name)) - elif input_type in ["string", "integer", "number"]: - gradio_inputs.append(gr.Textbox(label=input_name)) - else: - error_message = f"Input type '{input_type}' not supported." - raise ValueError(error_message) + input_gradio_component_class = TYPE_TO_COMPONENT_CLASS_MAPPING[input_details["type"]] + new_component = input_gradio_component_class(label=input_name) + gradio_inputs.append(new_component) - gradio_output = tool_class.output_type - assert gradio_output in ["string", "image", "audio"], f"Output type '{gradio_output}' not supported." + output_gradio_componentclass = TYPE_TO_COMPONENT_CLASS_MAPPING[tool_class.output_type] + gradio_output = output_gradio_componentclass(label=input_name) gr.Interface( fn=fn,