Skip to content

Commit

Permalink
Add ONNX export support for ModernBERT (#2131)
Browse files Browse the repository at this point in the history
* Introduce `DisableCompileContextManager`

* DisableCompileContextManager definition

* Add ONNX export support for modernbert

* Always use `DisableCompileContextManager` during export

* Add modernbert to listed models

* Add modernbert unit tests
  • Loading branch information
xenova authored Jan 7, 2025
1 parent fea8e44 commit 72498dd
Show file tree
Hide file tree
Showing 6 changed files with 46 additions and 16 deletions.
1 change: 1 addition & 0 deletions docs/source/exporters/onnx/overview.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ Supported architectures from [🤗 Transformers](https://huggingface.co/docs/tra
- MobileVit
- MobileNet v1
- MobileNet v2
- ModernBert
- MPNet
- MT5
- Musicgen (text-conditional only)
Expand Down
34 changes: 18 additions & 16 deletions optimum/exporters/onnx/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from ...utils import DEFAULT_DUMMY_SHAPES, logging
from ...utils.save_utils import maybe_load_preprocessors
from ..tasks import TasksManager
from ..utils import DisableCompileContextManager
from .constants import SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED
from .convert import onnx_export_from_model

Expand Down Expand Up @@ -300,22 +301,23 @@ def main_export(
if model_type in SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED and _transformers_version >= version.parse("4.35.99"):
loading_kwargs["attn_implementation"] = "eager"

model = TasksManager.get_model_from_task(
task,
model_name_or_path,
subfolder=subfolder,
revision=revision,
cache_dir=cache_dir,
token=token,
local_files_only=local_files_only,
force_download=force_download,
trust_remote_code=trust_remote_code,
framework=framework,
torch_dtype=torch_dtype,
device=device,
library_name=library_name,
**loading_kwargs,
)
with DisableCompileContextManager():
model = TasksManager.get_model_from_task(
task,
model_name_or_path,
subfolder=subfolder,
revision=revision,
cache_dir=cache_dir,
token=token,
local_files_only=local_files_only,
force_download=force_download,
trust_remote_code=trust_remote_code,
framework=framework,
torch_dtype=torch_dtype,
device=device,
library_name=library_name,
**loading_kwargs,
)

needs_pad_token_id = task == "text-classification" and getattr(model.config, "pad_token_id", None) is None

Expand Down
4 changes: 4 additions & 0 deletions optimum/exporters/onnx/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,10 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
return {"input_ids": dynamic_axis, "attention_mask": dynamic_axis}


class ModernBertOnnxConfig(DistilBertOnnxConfig):
    """ONNX export configuration for ModernBERT.

    Adds nothing of its own: every attribute and method is inherited
    unchanged from `DistilBertOnnxConfig`.
    """


class MPNetOnnxConfig(DistilBertOnnxConfig):
DEFAULT_ONNX_OPSET = 12 # For lower opsets, results in: Type 'tensor(int64)' of input parameter (/0/auto_model/encoder/Add_1_output_0) of operator (Min) in node (/0/auto_model/encoder/Min) is invalid.

Expand Down
9 changes: 9 additions & 0 deletions optimum/exporters/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -891,6 +891,15 @@ class TasksManager:
"image-classification",
onnx="MobileNetV2OnnxConfig",
),
"modernbert": supported_tasks_mapping(
"feature-extraction",
"fill-mask",
"text-classification",
"multiple-choice",
"token-classification",
"question-answering",
onnx="ModernBertOnnxConfig",
),
"mpnet": supported_tasks_mapping(
"feature-extraction",
"fill-mask",
Expand Down
12 changes: 12 additions & 0 deletions optimum/exporters/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -704,3 +704,15 @@ def check_dummy_inputs_are_allowed(
raise ValueError(
f"Config dummy inputs are not a subset of the model inputs: {dummy_input_names} vs {forward_inputs_set}"
)


class DisableCompileContextManager:
    """Context manager that temporarily turns `torch.compile` into a no-op.

    While the context is active, `torch.compile` performs no compilation:

    - called with a target (`torch.compile(fn)` or a bare `@torch.compile`),
      it returns the target unchanged;
    - called with options only (e.g. `@torch.compile(dynamic=True)`), it
      returns an identity decorator.

    The `torch.compile` that was installed when the manager was constructed is
    restored on exit, including when the body of the `with` statement raises.
    """

    def __init__(self):
        # Remember the real implementation so that `__exit__` can put it back.
        self._original_compile = torch.compile

    def __enter__(self):
        def _noop_compile(model=None, *args, **kwargs):
            # Factory form, `torch.compile(**options)`: hand back a decorator
            # that leaves the decorated callable untouched.
            if model is None:
                return lambda fn: fn
            # Direct form, `torch.compile(fn)`: return the callable/module itself
            # rather than an identity function that would swallow the call.
            return model

        # Turn torch.compile into a no-op
        torch.compile = _noop_compile
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Returns None, so any exception raised in the body still propagates.
        torch.compile = self._original_compile
2 changes: 2 additions & 0 deletions tests/exporters/exporters_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@
"mobilenet-v2": "hf-internal-testing/tiny-random-MobileNetV2Model",
"mobilenet-v1": "google/mobilenet_v1_0.75_192",
"mobilevit": "hf-internal-testing/tiny-random-mobilevit",
"modernbert": "hf-internal-testing/tiny-random-ModernBertForMaskedLM",
"mpnet": "hf-internal-testing/tiny-random-MPNetModel",
"mpt": "hf-internal-testing/tiny-random-MptForCausalLM",
"mt5": "lewtun/tiny-random-mt5",
Expand Down Expand Up @@ -266,6 +267,7 @@
# "mobilenet_v1": "google/mobilenet_v1_0.75_192",
# "mobilenet_v2": "google/mobilenet_v2_0.35_96",
"mobilevit": "apple/mobilevit-small",
"modernbert": "answerdotai/ModernBERT-base",
"mpt": "mosaicml/mpt-7b",
"mt5": "lewtun/tiny-random-mt5", # Not using google/mt5-small because it takes too much time for testing.
"musicgen": "facebook/musicgen-small",
Expand Down

0 comments on commit 72498dd

Please sign in to comment.