Skip to content

Commit

Permalink
Add RemBERT ONNX support (#2108)
Browse files Browse the repository at this point in the history
* ONNX config for RemBERT added

* added RemBERT to TasksManager

* rembert added to exporters_utils

* RemBERT added to test modelling tasks

* changed rembert model

* added RemBERT to test utils

* Added RemBERT to documentation

* Apply suggestions from code review

---------

Co-authored-by: Ilyas Moutawwakil <[email protected]>
  • Loading branch information
mlynatom and IlyasMoutawwakil authored Dec 2, 2024
1 parent 28bd0ad commit f22655c
Show file tree
Hide file tree
Showing 6 changed files with 22 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/source/exporters/onnx/overview.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ Supported architectures from [🤗 Transformers](https://huggingface.co/docs/tra
- PoolFormer
- Qwen2(Qwen1.5)
- RegNet
- RemBERT
- ResNet
- Roberta
- Roformer
Expand Down
4 changes: 4 additions & 0 deletions optimum/exporters/onnx/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,10 @@ class SplinterOnnxConfig(BertOnnxConfig):
DEFAULT_ONNX_OPSET = 11


class RemBertOnnxConfig(BertOnnxConfig):
    """ONNX export configuration for RemBERT.

    Inherits everything from ``BertOnnxConfig`` and only sets the default
    ONNX opset used for export.
    """

    # Default ONNX opset for this architecture.
    DEFAULT_ONNX_OPSET = 11


class DistilBertOnnxConfig(BertOnnxConfig):
DEFAULT_ONNX_OPSET = 14 # now uses F.scaled_dot_product_attention by default for transformers>=4.46.0

Expand Down
9 changes: 9 additions & 0 deletions optimum/exporters/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,15 @@ class TasksManager:
onnx="BertOnnxConfig",
tflite="BertTFLiteConfig",
),
"rembert": supported_tasks_mapping(
"fill-mask",
"feature-extraction",
"text-classification",
"multiple-choice",
"token-classification",
"question-answering",
onnx="RemBertOnnxConfig",
),
# For big-bird and bigbird-pegasus being unsupported, refer to model_configs.py
# "big-bird": supported_tasks_mapping(
# "feature-extraction",
Expand Down
3 changes: 2 additions & 1 deletion tests/exporters/exporters_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@
"phi3": "Xenova/tiny-random-Phi3ForCausalLM",
"pix2struct": "fxmarty/pix2struct-tiny-random",
# "rembert": "google/rembert",
"rembert": "hf-internal-testing/tiny-random-RemBertModel",
"poolformer": "hf-internal-testing/tiny-random-PoolFormerModel",
"qwen2": "fxmarty/tiny-dummy-qwen2",
"regnet": "hf-internal-testing/tiny-random-RegNetModel",
Expand Down Expand Up @@ -257,7 +258,7 @@
"owlv2": "google/owlv2-base-patch16",
"owlvit": "google/owlvit-base-patch32",
"perceiver": "hf-internal-testing/tiny-random-PerceiverModel", # Not using deepmind/language-perceiver because it takes too much time for testing.
# "rembert": "google/rembert",
"rembert": "google/rembert",
"poolformer": "hf-internal-testing/tiny-random-PoolFormerModel",
"regnet": "facebook/regnet-y-040",
"resnet": "microsoft/resnet-50",
Expand Down
5 changes: 5 additions & 0 deletions tests/onnxruntime/test_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -1312,6 +1312,7 @@ class ORTModelForQuestionAnsweringIntegrationTest(ORTModelTestMixin):
"squeezebert",
"xlm_qa",
"xlm_roberta",
"rembert",
]

FULL_GRID = {"model_arch": SUPPORTED_ARCHITECTURES}
Expand Down Expand Up @@ -1502,6 +1503,7 @@ class ORTModelForMaskedLMIntegrationTest(ORTModelTestMixin):
"squeezebert",
"xlm",
"xlm_roberta",
"rembert",
]

FULL_GRID = {"model_arch": SUPPORTED_ARCHITECTURES}
Expand Down Expand Up @@ -1682,6 +1684,7 @@ class ORTModelForSequenceClassificationIntegrationTest(ORTModelTestMixin):
"squeezebert",
"xlm",
"xlm_roberta",
"rembert",
]

FULL_GRID = {"model_arch": SUPPORTED_ARCHITECTURES}
Expand Down Expand Up @@ -1882,6 +1885,7 @@ class ORTModelForTokenClassificationIntegrationTest(ORTModelTestMixin):
"squeezebert",
"xlm",
"xlm_roberta",
"rembert",
]

FULL_GRID = {"model_arch": SUPPORTED_ARCHITECTURES}
Expand Down Expand Up @@ -2227,6 +2231,7 @@ class ORTModelForMultipleChoiceIntegrationTest(ORTModelTestMixin):
"squeezebert",
"xlm",
"xlm_roberta",
"rembert",
]

FULL_GRID = {"model_arch": SUPPORTED_ARCHITECTURES}
Expand Down
1 change: 1 addition & 0 deletions tests/onnxruntime/utils_onnxruntime_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@
"pix2struct": "fxmarty/pix2struct-tiny-random",
"poolformer": "hf-internal-testing/tiny-random-PoolFormerModel",
"qwen2": "fxmarty/tiny-dummy-qwen2",
"rembert": "hf-internal-testing/tiny-random-RemBertModel",
"resnet": "hf-internal-testing/tiny-random-resnet",
"roberta": "hf-internal-testing/tiny-random-RobertaModel",
"roformer": "hf-internal-testing/tiny-random-RoFormerModel",
Expand Down

0 comments on commit f22655c

Please sign in to comment.