Skip to content

Commit

Permalink
Add ONNX export support for GIT
Browse files Browse the repository at this point in the history
  • Loading branch information
marcindulak committed Dec 19, 2024
1 parent 4daa408 commit 9696a7f
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/source/exporters/onnx/overview.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ Supported architectures from [🤗 Transformers](https://huggingface.co/docs/tra
- ESM
- Falcon
- Flaubert
- GIT
- GPT-2
- GPT-BigCode
- GPT-J
Expand Down
20 changes: 20 additions & 0 deletions optimum/exporters/onnx/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2499,3 +2499,23 @@ class EncoderDecoderOnnxConfig(EncoderDecoderBaseOnnxConfig):
NORMALIZED_CONFIG_CLASS = NormalizedEncoderDecoderConfig

DEFAULT_ONNX_OPSET = 14 # uses SDPA in Transformers, hence opset>=14.


class GITOnnxConfig(VisionOnnxConfig):
NORMALIZED_CONFIG_CLASS = NormalizedVisionConfig
DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, DummyVisionInputGenerator)

@property
def inputs(self) -> Dict[str, Dict[int, str]]:
return {
"input_ids": {0: "text_batch_size", 1: "sequence_length"},
"pixel_values": {0: "image_batch_size", 1: "num_channels", 2: "height", 3: "width"}
}


class GITVisionModelOnnxConfig(VisionOnnxConfig):
NORMALIZED_CONFIG_CLASS = NormalizedVisionConfig

@property
def inputs(self) -> Dict[str, Dict[int, str]]:
return {"pixel_values": {0: "batch_size", 1: "num_channels", 2: "height", 3: "width"}}
11 changes: 11 additions & 0 deletions optimum/exporters/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,6 +680,17 @@ class TasksManager:
"text-classification",
onnx="GemmaOnnxConfig",
),
"git": supported_tasks_mapping(
"feature-extraction",
"image-text-to-text",
"image-to-text",
onnx="GITOnnxConfig",
),
"git-vision-model": supported_tasks_mapping(
"feature-extraction",
"image-to-text",
onnx="GITVisionModelOnnxConfig",
),
"glpn": supported_tasks_mapping(
"feature-extraction",
"depth-estimation",
Expand Down
8 changes: 8 additions & 0 deletions tests/exporters/exporters_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,14 @@
},
"flaubert": "hf-internal-testing/tiny-random-flaubert",
"gemma": "fxmarty/tiny-random-GemmaForCausalLM",
"git": {
"hf-internal-testing/tiny-random-GitModel": [
"feature-extraction",
],
"hf-internal-testing/tiny-random-GitForCausalLM": [
"image-text-to-text",
],
},
"glpn": "hf-internal-testing/tiny-random-GLPNModel",
"gpt2": "hf-internal-testing/tiny-random-gpt2",
"gpt-bigcode": "hf-internal-testing/tiny-random-GPTBigCodeModel",
Expand Down
8 changes: 8 additions & 0 deletions tests/onnxruntime/utils_onnxruntime_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,14 @@
"flaubert": "hf-internal-testing/tiny-random-flaubert",
"flux": "optimum-internal-testing/tiny-random-flux",
"gemma": "fxmarty/tiny-random-GemmaForCausalLM",
"git": {
"hf-internal-testing/tiny-random-GitModel": [
"feature-extraction",
],
"hf-internal-testing/tiny-random-GitForCausalLM": [
"image-text-to-text",
],
},
"gpt2": "hf-internal-testing/tiny-random-gpt2",
"gpt_bigcode": "hf-internal-testing/tiny-random-GPTBigCodeModel",
"gpt_neo": "hf-internal-testing/tiny-random-GPTNeoModel",
Expand Down

0 comments on commit 9696a7f

Please sign in to comment.