diff --git a/docs/source/exporters/onnx/overview.mdx b/docs/source/exporters/onnx/overview.mdx index b5129c23f2..18c75953c3 100644 --- a/docs/source/exporters/onnx/overview.mdx +++ b/docs/source/exporters/onnx/overview.mdx @@ -74,6 +74,7 @@ Supported architectures from [🤗 Transformers](https://huggingface.co/docs/tra - MobileVit - MobileNet v1 - MobileNet v2 +- ModernBert - MPNet - MT5 - Musicgen (text-conditional only) diff --git a/optimum/exporters/onnx/__main__.py b/optimum/exporters/onnx/__main__.py index 6a2cc6834a..280c6fc655 100644 --- a/optimum/exporters/onnx/__main__.py +++ b/optimum/exporters/onnx/__main__.py @@ -29,6 +29,7 @@ from ...utils import DEFAULT_DUMMY_SHAPES, logging from ...utils.save_utils import maybe_load_preprocessors from ..tasks import TasksManager +from ..utils import DisableCompileContextManager from .constants import SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED from .convert import onnx_export_from_model @@ -300,22 +301,23 @@ def main_export( if model_type in SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED and _transformers_version >= version.parse("4.35.99"): loading_kwargs["attn_implementation"] = "eager" - model = TasksManager.get_model_from_task( - task, - model_name_or_path, - subfolder=subfolder, - revision=revision, - cache_dir=cache_dir, - token=token, - local_files_only=local_files_only, - force_download=force_download, - trust_remote_code=trust_remote_code, - framework=framework, - torch_dtype=torch_dtype, - device=device, - library_name=library_name, - **loading_kwargs, - ) + with DisableCompileContextManager(): + model = TasksManager.get_model_from_task( + task, + model_name_or_path, + subfolder=subfolder, + revision=revision, + cache_dir=cache_dir, + token=token, + local_files_only=local_files_only, + force_download=force_download, + trust_remote_code=trust_remote_code, + framework=framework, + torch_dtype=torch_dtype, + device=device, + library_name=library_name, + **loading_kwargs, + ) needs_pad_token_id = task == "text-classification" and getattr(model.config, "pad_token_id", None) is None diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index 3a48a579c2..8966a1b1a3 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -181,6 +181,10 @@ def inputs(self) -> Dict[str, Dict[int, str]]: return {"input_ids": dynamic_axis, "attention_mask": dynamic_axis} +class ModernBertOnnxConfig(DistilBertOnnxConfig): + pass + + class MPNetOnnxConfig(DistilBertOnnxConfig): DEFAULT_ONNX_OPSET = 12 # For lower opsets, results in: Type 'tensor(int64)' of input parameter (/0/auto_model/encoder/Add_1_output_0) of operator (Min) in node (/0/auto_model/encoder/Min) is invalid. diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py index 7cb5a31d2d..59c066ac38 100644 --- a/optimum/exporters/tasks.py +++ b/optimum/exporters/tasks.py @@ -891,6 +891,15 @@ class TasksManager: "image-classification", onnx="MobileNetV2OnnxConfig", ), + "modernbert": supported_tasks_mapping( + "feature-extraction", + "fill-mask", + "text-classification", + "multiple-choice", + "token-classification", + "question-answering", + onnx="ModernBertOnnxConfig", + ), "mpnet": supported_tasks_mapping( "feature-extraction", "fill-mask", diff --git a/optimum/exporters/utils.py b/optimum/exporters/utils.py index 02b1d0fe3a..58e170ba97 100644 --- a/optimum/exporters/utils.py +++ b/optimum/exporters/utils.py @@ -704,3 +704,15 @@ def check_dummy_inputs_are_allowed( raise ValueError( f"Config dummy inputs are not a subset of the model inputs: {dummy_input_names} vs {forward_inputs_set}" ) + + +class DisableCompileContextManager: + def __init__(self): + self._original_compile = torch.compile + + def __enter__(self): + # Turn torch.compile into a no-op + torch.compile = lambda *args, **kwargs: lambda x: x + + def __exit__(self, exc_type, exc_val, exc_tb): + torch.compile = self._original_compile diff --git a/tests/exporters/exporters_utils.py b/tests/exporters/exporters_utils.py index 900b5f3b5c..d256e16dd4 100644 --- a/tests/exporters/exporters_utils.py +++ b/tests/exporters/exporters_utils.py @@ -125,6 +125,7 @@ "mobilenet-v2": "hf-internal-testing/tiny-random-MobileNetV2Model", "mobilenet-v1": "google/mobilenet_v1_0.75_192", "mobilevit": "hf-internal-testing/tiny-random-mobilevit", + "modernbert": "hf-internal-testing/tiny-random-ModernBertForMaskedLM", "mpnet": "hf-internal-testing/tiny-random-MPNetModel", "mpt": "hf-internal-testing/tiny-random-MptForCausalLM", "mt5": "lewtun/tiny-random-mt5", @@ -266,6 +267,7 @@ # "mobilenet_v1": "google/mobilenet_v1_0.75_192", # "mobilenet_v2": "google/mobilenet_v2_0.35_96", "mobilevit": "apple/mobilevit-small", + "modernbert": "answerdotai/ModernBERT-base", "mpt": "mosaicml/mpt-7b", "mt5": "lewtun/tiny-random-mt5", # Not using google/mt5-small because it takes too much time for testing. "musicgen": "facebook/musicgen-small",