diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py
index cc752779d30..0b0edd04a99 100644
--- a/optimum/exporters/onnx/model_configs.py
+++ b/optimum/exporters/onnx/model_configs.py
@@ -117,6 +117,22 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
             "token_type_ids": dynamic_axis,
         }
 
+class VisualBertOnnxConfig(TextAndVisionOnnxConfig):
+    # NOTE: VisualBERT consumes precomputed visual embeddings produced by an
+    # external detector -- it has no image backbone, so there is no "pixel_values" input.
+    DEFAULT_ONNX_OPSET = 11
+
+    @property
+    def inputs(self) -> Dict[str, Dict[int, str]]:
+        return {
+            "input_ids": {0: "batch_size", 1: "sequence_length"},
+            "attention_mask": {0: "batch_size", 1: "sequence_length"},
+            "token_type_ids": {0: "batch_size", 1: "sequence_length"},
+            "visual_embeds": {0: "batch_size", 1: "visual_seq_length", 2: "visual_embedding_dim"},
+            "visual_attention_mask": {0: "batch_size", 1: "visual_seq_length"},
+            "visual_token_type_ids": {0: "batch_size", 1: "visual_seq_length"},
+        }
+
 class AlbertOnnxConfig(BertOnnxConfig):
     DEFAULT_ONNX_OPSET = 14  # now uses F.scaled_dot_product_attention by default for torch>=2.1.1.
 
diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py
index fdc8bfcb539..3630e5fa95a 100644
--- a/optimum/exporters/tasks.py
+++ b/optimum/exporters/tasks.py
@@ -1108,6 +1108,12 @@ class TasksManager:
             "text-to-audio",
             onnx="VitsOnnxConfig",
         ),
+        "visualbert": supported_tasks_mapping(
+            "feature-extraction",
+            "multiple-choice",
+            "question-answering",
+            onnx="VisualBertOnnxConfig",
+        ),
         "wavlm": supported_tasks_mapping(
             "feature-extraction",
             "automatic-speech-recognition",
diff --git a/tests/exporters/exporters_utils.py b/tests/exporters/exporters_utils.py
index ccccb5510bf..87cbed4f331 100644
--- a/tests/exporters/exporters_utils.py
+++ b/tests/exporters/exporters_utils.py
@@ -197,6 +197,7 @@
             "document-question-answering-with-past",
         ],
     },
+    "visualbert": "hf-internal-testing/tiny-random-VisualBertModel",
 }
 
 
@@ -286,6 +287,7 @@
     "speech-to-text": "codenamewei/speech-to-text",
     "xlm": "xlm-clm-ende-1024",
     "xlm-roberta": "Unbabel/xlm-roberta-comet-small",
+    "visualbert": "uclanlp/visualbert-vqa-coco-pre",
 }
 
 TENSORFLOW_EXPORT_MODELS = {