diff --git a/docs/source/exporters/onnx/overview.mdx b/docs/source/exporters/onnx/overview.mdx
index fbe7b42c44f..5de73bb36f0 100644
--- a/docs/source/exporters/onnx/overview.mdx
+++ b/docs/source/exporters/onnx/overview.mdx
@@ -46,6 +46,7 @@ Supported architectures from [🤗 Transformers](https://huggingface.co/docs/tra
 - ESM
 - Falcon
 - Flaubert
+- GIT
 - GPT-2
 - GPT-BigCode
 - GPT-J
diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py
index 1c838408807..3cf16b82949 100644
--- a/optimum/exporters/onnx/model_configs.py
+++ b/optimum/exporters/onnx/model_configs.py
@@ -2499,3 +2499,23 @@ class EncoderDecoderOnnxConfig(EncoderDecoderBaseOnnxConfig):
     NORMALIZED_CONFIG_CLASS = NormalizedEncoderDecoderConfig
 
     DEFAULT_ONNX_OPSET = 14  # uses SDPA in Transformers, hence opset>=14.
+
+
+class GITOnnxConfig(VisionOnnxConfig):
+    NORMALIZED_CONFIG_CLASS = NormalizedVisionConfig
+    DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, DummyVisionInputGenerator)
+
+    @property
+    def inputs(self) -> Dict[str, Dict[int, str]]:
+        return {
+            "input_ids": {0: "text_batch_size", 1: "sequence_length"},
+            "pixel_values": {0: "image_batch_size", 1: "num_channels", 2: "height", 3: "width"}
+        }
+
+
+class GITVisionModelOnnxConfig(VisionOnnxConfig):
+    NORMALIZED_CONFIG_CLASS = NormalizedVisionConfig
+
+    @property
+    def inputs(self) -> Dict[str, Dict[int, str]]:
+        return {"pixel_values": {0: "batch_size", 1: "num_channels", 2: "height", 3: "width"}}
diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py
index 4db4130302d..cac2fc8c4f9 100644
--- a/optimum/exporters/tasks.py
+++ b/optimum/exporters/tasks.py
@@ -680,6 +680,17 @@ class TasksManager:
             "text-classification",
             onnx="GemmaOnnxConfig",
         ),
+        "git": supported_tasks_mapping(
+            "feature-extraction",
+            "image-text-to-text",
+            "image-to-text",
+            onnx="GITOnnxConfig",
+        ),
+        "git-vision-model": supported_tasks_mapping(
+            "feature-extraction",
+            "image-to-text",
+            onnx="GITVisionModelOnnxConfig",
+        ),
         "glpn": supported_tasks_mapping(
             "feature-extraction",
             "depth-estimation",
diff --git a/tests/exporters/exporters_utils.py b/tests/exporters/exporters_utils.py
index e04a850bc8c..14ed253ae94 100644
--- a/tests/exporters/exporters_utils.py
+++ b/tests/exporters/exporters_utils.py
@@ -95,6 +95,14 @@
     },
     "flaubert": "hf-internal-testing/tiny-random-flaubert",
     "gemma": "fxmarty/tiny-random-GemmaForCausalLM",
+    "git": {
+        "hf-internal-testing/tiny-random-GitModel": [
+            "feature-extraction",
+        ],
+        "hf-internal-testing/tiny-random-GitForCausalLM": [
+            "image-text-to-text",
+        ],
+    },
     "glpn": "hf-internal-testing/tiny-random-GLPNModel",
     "gpt2": "hf-internal-testing/tiny-random-gpt2",
     "gpt-bigcode": "hf-internal-testing/tiny-random-GPTBigCodeModel",
diff --git a/tests/onnxruntime/utils_onnxruntime_tests.py b/tests/onnxruntime/utils_onnxruntime_tests.py
index c33c07fc7b1..2f4225ce876 100644
--- a/tests/onnxruntime/utils_onnxruntime_tests.py
+++ b/tests/onnxruntime/utils_onnxruntime_tests.py
@@ -100,6 +100,14 @@
     "flaubert": "hf-internal-testing/tiny-random-flaubert",
     "flux": "optimum-internal-testing/tiny-random-flux",
     "gemma": "fxmarty/tiny-random-GemmaForCausalLM",
+    "git": {
+        "hf-internal-testing/tiny-random-GitModel": [
+            "feature-extraction",
+        ],
+        "hf-internal-testing/tiny-random-GitForCausalLM": [
+            "image-text-to-text",
+        ],
+    },
     "gpt2": "hf-internal-testing/tiny-random-gpt2",
     "gpt_bigcode": "hf-internal-testing/tiny-random-GPTBigCodeModel",
     "gpt_neo": "hf-internal-testing/tiny-random-GPTNeoModel",