From b789134caa4029b27753f05073993064ae1b8090 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Thu, 21 Nov 2024 09:44:48 +0000 Subject: [PATCH 1/5] feat: expose more params in `main_export` for customized export --- optimum/exporters/onnx/__main__.py | 89 +++++++++++++++++------ optimum/exporters/onnx/base.py | 28 +++++--- optimum/exporters/onnx/convert.py | 112 ++++++++++++++++++++++++++--- 3 files changed, 190 insertions(+), 39 deletions(-) diff --git a/optimum/exporters/onnx/__main__.py b/optimum/exporters/onnx/__main__.py index 6a2cc6834a6..f481a495357 100644 --- a/optimum/exporters/onnx/__main__.py +++ b/optimum/exporters/onnx/__main__.py @@ -22,7 +22,7 @@ from packaging import version from requests.exceptions import ConnectionError as RequestsConnectionError from transformers import AutoConfig, AutoTokenizer -from transformers.utils import is_torch_available +from transformers.utils import is_torch_available, is_tf_available from ...commands.export.onnx import parse_args_onnx from ...configuration_utils import _transformers_version @@ -36,17 +36,30 @@ if is_torch_available(): import torch -from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union +if is_tf_available(): + import tensorflow as tf + +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union, List if TYPE_CHECKING: from .base import OnnxConfig + from ...utils.import_utils import is_onnxruntime_available, is_diffusers_available + if is_onnxruntime_available(): + from onnxruntime import SessionOptions + if is_torch_available(): + from transformers.modeling_utils import PreTrainedModel + if is_diffusers_available(): + from diffusers import DiffusionPipeline + if is_tf_available(): + from transformers.modeling_tf_utils import TFPreTrainedModel logger = logging.get_logger() def main_export( - model_name_or_path: str, + model_name_or_path_or_obj: Union[str, "PreTrainedModel", + "TFPreTrainedModel", "DiffusionPipeline"], output: Union[str, Path], task: str = "auto", opset: Optional[int] = None, @@ -78,6 +91,10 @@ def main_export( legacy: bool = False, no_dynamic_axes: bool = False, do_constant_folding: bool = True, + disable_dynamic_axes_fix: bool = False, + custom_export_fn: Optional[Callable[..., None]] = None, + providers: Optional[List[str]] = None, + session_options: Optional["SessionOptions"] = None, **kwargs_shapes, ): """ @@ -86,8 +103,9 @@ def main_export( Args: > Required parameters - model_name_or_path (`str`): + model_name_or_path_or_obj (`Union[str, "PreTrainedModel", "TFPreTrainedModel", "DiffusionPipeline"]`): Model ID on huggingface.co or path on disk to the model repository to export. Example: `model_name_or_path="BAAI/bge-m3"` or `mode_name_or_path="/path/to/model_folder`. + It is also possible to pass a model object to skip getting models from the export task. output (`Union[str, Path]`): Path indicating the directory where to store the generated ONNX model. @@ -166,6 +184,14 @@ def main_export( If True, disables the use of dynamic axes during ONNX export. do_constant_folding (bool, defaults to `True`): PyTorch-specific argument. If `True`, the PyTorch ONNX export will fold constants into adjacent nodes, if possible. + disable_dynamic_axes_fix (`Optional[bool]`, defaults to `False`): + Whether to disable the default dynamic axes fixing. + custom_export_fn (`Optional[Callable[..., None]]`, defaults to `None`): + Customized PyTorch ONNX export function. If `None` provided, `torch.onnx.export` is be used. 
+ providers (`Optional[List[str]]`, defaults to `None`): + ONNXRuntime execution provides used for the dynamic axis fix and the model validation. If `None` provided, it will be determined by the `device` param. + session_options (`Optional["SessionOptions"]`, defaults to `None`): + ONNXRuntime session options used for the dynamic axis fix and the model validation. If `None` provided, a default `SessionOptions` object will be created. **kwargs_shapes (`Dict`): Shapes to use during inference. This argument allows to override the default shapes used during the ONNX export. @@ -226,6 +252,24 @@ def main_export( "Please use one of the following tasks instead: `text-to-image`, `image-to-image`, `inpainting`." ) + if isinstance(model_name_or_path_or_obj, str): + model = None + model_name_or_path = model_name_or_path_or_obj + else: + model = model_name_or_path_or_obj + model_name_or_path = model.config._name_or_path + + if providers is None: + if device.startswith("cuda"): + if (is_torch_available() + and torch.version.hip) or (is_tf_available() + and tf.test.is_built_with_rocm()): + providers = ["ROCMExecutionProvider"] + else: + providers = ["CUDAExecutionProvider"] + else: + providers = ["CPUExecutionProvider"] + original_task = task task = TasksManager.map_from_synonym(task) @@ -300,22 +344,23 @@ def main_export( if model_type in SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED and _transformers_version >= version.parse("4.35.99"): loading_kwargs["attn_implementation"] = "eager" - model = TasksManager.get_model_from_task( - task, - model_name_or_path, - subfolder=subfolder, - revision=revision, - cache_dir=cache_dir, - token=token, - local_files_only=local_files_only, - force_download=force_download, - trust_remote_code=trust_remote_code, - framework=framework, - torch_dtype=torch_dtype, - device=device, - library_name=library_name, - **loading_kwargs, - ) + if model is None: + model = TasksManager.get_model_from_task( + task, + model_name_or_path, + subfolder=subfolder, + revision=revision, + cache_dir=cache_dir, + token=token, + local_files_only=local_files_only, + force_download=force_download, + trust_remote_code=trust_remote_code, + framework=framework, + torch_dtype=torch_dtype, + device=device, + library_name=library_name, + **loading_kwargs, + ) needs_pad_token_id = task == "text-classification" and getattr(model.config, "pad_token_id", None) is None @@ -390,6 +435,10 @@ def main_export( task=task, use_subprocess=use_subprocess, do_constant_folding=do_constant_folding, + disable_dynamic_axes_fix=disable_dynamic_axes_fix, + custom_export_fn=custom_export_fn, + providers=providers, + session_options=session_options, **kwargs_shapes, ) diff --git a/optimum/exporters/onnx/base.py b/optimum/exporters/onnx/base.py index 7e35691d54b..8a065c86a01 100644 --- a/optimum/exporters/onnx/base.py +++ b/optimum/exporters/onnx/base.py @@ -63,6 +63,9 @@ if is_diffusers_available(): from diffusers import ModelMixin + if is_onnxruntime_available(): + from onnxruntime import SessionOptions + from .model_patcher import PatchingSpec logger = logging.get_logger(__name__) @@ -269,9 +272,12 @@ def variant(self, value: str): raise ValueError(f"The variant {value} is not supported for the ONNX config {self.__class__.__name__}.") self._variant = value - def fix_dynamic_axes( - self, model_path: "Path", device: str = "cpu", dtype: Optional[str] = None, input_shapes: Optional[Dict] = None - ): + def fix_dynamic_axes(self, + model_path: "Path", + providers: List[str], + session_options: Optional["SessionOptions"] = None, + 
dtype: Optional[str] = None, + input_shapes: Optional[Dict] = None): """ Fixes potential issues with dynamic axes. During the export, ONNX will infer some axes to be dynamic which are actually static. This method is called @@ -280,6 +286,10 @@ def fix_dynamic_axes( Args: model_path (`Path`): The path of the freshly exported ONNX model. + providers (`Optional[List[str]]`): + ONNXRuntime execution provides used for the dynamic axis fix and the model validation. + session_options (`Optional["SessionOptions"]`, defaults to `None`): + ONNXRuntime session options used for the dynamic axis fix and the model validation. If `None` provided, a default `SessionOptions` object will be created. """ if not (is_onnx_available() and is_onnxruntime_available()): raise RuntimeError( @@ -296,12 +306,10 @@ def fix_dynamic_axes( for output in self.outputs.values(): allowed_dynamic_axes |= set(output.values()) - if device.startswith("cuda"): - providers = ["CUDAExecutionProvider"] - else: - providers = ["CPUExecutionProvider"] - - session_options = SessionOptions() + if session_options is None: + session_options = SessionOptions() + # backup the original optimization level + original_opt_level = session_options.graph_optimization_level session_options.graph_optimization_level = GraphOptimizationLevel.ORT_DISABLE_ALL # no need to optimize here session = InferenceSession(model_path.as_posix(), providers=providers, sess_options=session_options) @@ -349,6 +357,8 @@ def fix_dynamic_axes( ) del onnx_model gc.collect() + # restore the original optimization level + session_options.graph_optimization_level = original_opt_level def patch_model_for_export( self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None diff --git a/optimum/exporters/onnx/convert.py b/optimum/exporters/onnx/convert.py index c12a9ac222a..37452a54716 100644 --- a/optimum/exporters/onnx/convert.py +++ b/optimum/exporters/onnx/convert.py @@ -22,7 +22,7 @@ from inspect import signature from itertools import chain from pathlib import Path -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union, TYPE_CHECKING import numpy as np import onnx @@ -65,8 +65,13 @@ from diffusers import DiffusionPipeline, ModelMixin if is_tf_available(): + import tensorflow as tf from transformers.modeling_tf_utils import TFPreTrainedModel +if TYPE_CHECKING: + from ...utils.import_utils import is_onnxruntime_available + if is_onnxruntime_available(): + from onnxruntime import SessionOptions logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -111,6 +116,8 @@ def validate_models_outputs( device: str = "cpu", use_subprocess: Optional[bool] = True, model_kwargs: Optional[Dict[str, Any]] = None, + providers: Optional[List[str]] = None, + session_options: Optional["SessionOptions"] = None, ): """ Validates the export of several models, by checking that the outputs from both the reference and the exported model match. @@ -137,6 +144,10 @@ def validate_models_outputs( model_kwargs (`Optional[Dict[str, Any]]`, defaults to `None`): Experimental usage: keyword arguments to pass to the model during the export and validation. + providers (`Optional[List[str]]`, defaults to `None`): + ONNXRuntime execution provides used for the dynamic axis fix and the model validation. If `None` provided, it will be determined by the `device` param. 
+ session_options (`Optional["SessionOptions"]`, defaults to `None`): + ONNXRuntime session options used for the dynamic axis fix and the model validation. If `None` provided, a default `SessionOptions` object will be created. Raises: ValueError: If the outputs shapes or values do not match between the reference and the exported model. """ @@ -174,6 +185,8 @@ def validate_models_outputs( device=device, use_subprocess=use_subprocess, model_kwargs=model_kwargs, + providers=providers, + session_options=session_options, ) except Exception as e: exceptions.append((onnx_model_path, e)) @@ -194,6 +207,8 @@ def validate_model_outputs( device: str = "cpu", use_subprocess: Optional[bool] = True, model_kwargs: Optional[Dict[str, Any]] = None, + providers: Optional[List[str]] = None, + session_options: Optional["SessionOptions"] = None, ): """ Validates the export by checking that the outputs from both the reference and the exported model match. @@ -217,6 +232,10 @@ def validate_model_outputs( model_kwargs (`Optional[Dict[str, Any]]`, defaults to `None`): Experimental usage: keyword arguments to pass to the model during the export and validation. + providers (`Optional[List[str]]`, defaults to `None`): + ONNXRuntime execution provides used for the dynamic axis fix and the model validation. If `None` provided, it will be determined by the `device` param. + session_options (`Optional["SessionOptions"]`, defaults to `None`): + ONNXRuntime session options used for the dynamic axis fix and the model validation. If `None` provided, a default `SessionOptions` object will be created. Raises: ValueError: If the outputs shapes or values do not match between the reference and the exported model. """ @@ -243,18 +262,23 @@ def validate_model_outputs( input_shapes, device, model_kwargs=model_kwargs, + providers=providers, + session_options=session_options, ) def _run_validation( config: OnnxConfig, - reference_model: Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"], + reference_model: Union["PreTrainedModel", "TFPreTrainedModel", + "ModelMixin"], onnx_model: Path, onnx_named_outputs: List[str], atol: Optional[float] = None, input_shapes: Optional[Dict] = None, device: str = "cpu", model_kwargs: Optional[Dict[str, Any]] = None, + providers: Optional[List[str]] = None, + session_options: Optional["SessionOptions"] = None, ): from onnxruntime import GraphOptimizationLevel, SessionOptions @@ -275,16 +299,25 @@ def _run_validation( reference_model_inputs = config.generate_dummy_inputs(framework=framework, **input_shapes) # Create ONNX Runtime session - session_options = SessionOptions() + if session_options is None: + session_options = SessionOptions() + # backup the original optimization level + original_opt_level = session_options.graph_optimization_level # We could well set ORT_DISABLE_ALL here, but it makes CUDA export with O4 of gpt_neo fail session_options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_BASIC - if device.startswith("cuda"): - provider = "CUDAExecutionProvider" - else: - provider = "CPUExecutionProvider" + if providers is None: + if device.startswith("cuda"): + if (is_torch_available() + and torch.version.hip) or (is_tf_available() + and tf.test.is_built_with_rocm()): + providers = ["ROCMExecutionProvider"] + else: + providers = ["CUDAExecutionProvider"] + else: + providers = ["CPUExecutionProvider"] - session = PickableInferenceSession(onnx_model.as_posix(), sess_options=session_options, providers=[provider]) + session = PickableInferenceSession(onnx_model.as_posix(), 
sess_options=session_options, providers=providers) # Sometimes the exported model can have more outputs than what is specified in the ONNX config because the original # PyTorch model has more outputs that were forgotten in the config, so we check for that. @@ -372,6 +405,9 @@ def _run_validation( # Compute outputs from the ONNX model onnx_outputs = session.run(onnx_named_outputs, onnx_inputs) + # restore the original optimization level + session_options.graph_optimization_level = original_opt_level + # Modify the ONNX output names to match the reference model output names onnx_to_torch = {v: k for k, v in config.torch_to_onnx_output_map.items()} onnx_named_outputs = [onnx_to_torch.get(k, k) for k in onnx_named_outputs] @@ -446,6 +482,8 @@ def __init__( input_shapes: Optional[Dict] = None, device: str = "cpu", model_kwargs: Optional[Dict[str, Any]] = None, + providers: Optional[List[str]] = None, + session_options: Optional["SessionOptions"] = None, ): super().__init__() self._pconn, self._cconn = mp.Pipe() @@ -458,6 +496,8 @@ def __init__( self.input_shapes = input_shapes self.device = device self.model_kwargs = model_kwargs + self.providers = providers + self.session_options = session_options def run(self): try: @@ -470,6 +510,8 @@ def run(self): input_shapes=self.input_shapes, device=self.device, model_kwargs=self.model_kwargs, + providers=self.providers, + session_options=self.session_options, ) except Exception as e: tb = traceback.format_exc() @@ -493,6 +535,7 @@ def export_pytorch( no_dynamic_axes: bool = False, do_constant_folding: bool = True, model_kwargs: Optional[Dict[str, Any]] = None, + custom_export_fn: Optional[Callable[..., None]] = None ) -> Tuple[List[str], List[str]]: """ Exports a PyTorch model to an ONNX Intermediate Representation. @@ -520,12 +563,17 @@ def export_pytorch( the export. This argument should be used along the `custom_onnx_config` argument in case, for example, the model inputs/outputs are changed (for example, if `model_kwargs={"output_attentions": True}` is passed). + custom_export_fn (`Optional[Callable[..., None]]`, defaults to `None`): + Customized PyTorch ONNX export function. If `None` provided, `torch.onnx.export` is be used. Returns: `Tuple[List[str], List[str]]`: A tuple with an ordered list of the model's inputs, and the named outputs from the ONNX configuration. """ - from torch.onnx import export as onnx_export + if custom_export_fn is not None: + onnx_export = custom_export_fn + else: + from torch.onnx import export as onnx_export from torch.utils._pytree import tree_map logger.info(f"Using framework PyTorch: {torch.__version__}") @@ -724,6 +772,9 @@ def export_models( no_dynamic_axes: bool = False, do_constant_folding: bool = True, model_kwargs: Optional[Dict[str, Any]] = None, + custom_export_fn: Optional[Callable[..., None]] = None, + providers: Optional[List[str]] = None, + session_options: Optional["SessionOptions"] = None, ) -> Tuple[List[List[str]], List[List[str]]]: """ Exports a Pytorch or TensorFlow encoder decoder model to an ONNX Intermediate Representation. @@ -758,6 +809,12 @@ def export_models( the export. This argument should be used along the `custom_onnx_config` argument in case, for example, the model inputs/outputs are changed (for example, if `model_kwargs={"output_attentions": True}` is passed). + custom_export_fn (`Optional[Callable[..., None]]`, defaults to `None`): + Customized PyTorch ONNX export function. If `None` provided, `torch.onnx.export` is be used. 
+ providers (`Optional[List[str]]`, defaults to `None`): + ONNXRuntime execution provides used for the dynamic axis fix and the model validation. If `None` provided, it will be determined by the `device` param. + session_options (`Optional["SessionOptions"]`, defaults to `None`): + ONNXRuntime session options used for the dynamic axis fix and the model validation. If `None` provided, a default `SessionOptions` object will be created. Returns: `Tuple[List[List[str]], List[List[str]]]`: A tuple with an ordered list of the model's inputs, and the named outputs from the ONNX configuration. @@ -792,6 +849,9 @@ def export_models( no_dynamic_axes=no_dynamic_axes, do_constant_folding=do_constant_folding, model_kwargs=model_kwargs, + custom_export_fn=custom_export_fn, + providers=providers, + session_options=session_options, ) ) @@ -811,6 +871,9 @@ def export( no_dynamic_axes: bool = False, do_constant_folding: bool = True, model_kwargs: Optional[Dict[str, Any]] = None, + custom_export_fn: Optional[Callable[..., None]] = None, + providers: Optional[List[str]] = None, + session_options: Optional["SessionOptions"] = None, ) -> Tuple[List[str], List[str]]: """ Exports a Pytorch or TensorFlow model to an ONNX Intermediate Representation. @@ -842,6 +905,12 @@ def export( the export. This argument should be used along the `custom_onnx_config` argument in case, for example, the model inputs/outputs are changed (for example, if `model_kwargs={"output_attentions": True}` is passed). + custom_export_fn (`Optional[Callable[..., None]]`, defaults to `None`): + Customized PyTorch ONNX export function. If `None` provided, `torch.onnx.export` is be used. + providers (`Optional[List[str]]`, defaults to `None`): + ONNXRuntime execution provides used for the dynamic axis fix and the model validation. If `None` provided, it will be determined by the `device` param. + session_options (`Optional["SessionOptions"]`, defaults to `None`): + ONNXRuntime session options used for the dynamic axis fix and the model validation. If `None` provided, a default `SessionOptions` object will be created. Returns: `Tuple[List[str], List[str]]`: A tuple with an ordered list of the model's inputs, and the named outputs from @@ -895,6 +964,7 @@ def export( no_dynamic_axes=no_dynamic_axes, do_constant_folding=do_constant_folding, model_kwargs=model_kwargs, + custom_export_fn=custom_export_fn ) elif is_tf_available() and issubclass(type(model), TFPreTrainedModel): @@ -914,7 +984,11 @@ def export( ) if not disable_dynamic_axes_fix: - config.fix_dynamic_axes(output, device=device, input_shapes=input_shapes, dtype=dtype) + config.fix_dynamic_axes(output, + providers=providers, + session_options=session_options, + input_shapes=input_shapes, + dtype=dtype) return export_output @@ -938,6 +1012,10 @@ def onnx_export_from_model( task: Optional[str] = None, use_subprocess: bool = False, do_constant_folding: bool = True, + disable_dynamic_axes_fix: bool = False, + custom_export_fn: Optional[Callable[..., None]] = None, + providers: Optional[List[str]] = None, + session_options: Optional["SessionOptions"] = None, **kwargs_shapes, ): """ @@ -993,6 +1071,14 @@ def onnx_export_from_model( If True, disables the use of dynamic axes during ONNX export. do_constant_folding (bool, defaults to `True`): PyTorch-specific argument. If `True`, the PyTorch ONNX export will fold constants into adjacent nodes, if possible. + disable_dynamic_axes_fix (`Optional[bool]`, defaults to `False`): + Whether to disable the default dynamic axes fixing. 
+ custom_export_fn (`Optional[Callable[..., None]]`, defaults to `None`): + Customized PyTorch ONNX export function. If `None` provided, `torch.onnx.export` is be used. + providers (`Optional[List[str]]`, defaults to `None`): + ONNXRuntime execution provides used for the dynamic axis fix and the model validation. If `None` provided, it will be determined by the `device` param. + session_options (`Optional["SessionOptions"]`, defaults to `None`): + ONNXRuntime session options used for the dynamic axis fix and the model validation. If `None` provided, a default `SessionOptions` object will be created. **kwargs_shapes (`Dict`): Shapes to use during inference. This argument allows to override the default shapes used during the ONNX export. @@ -1205,6 +1291,10 @@ def onnx_export_from_model( no_dynamic_axes=no_dynamic_axes, do_constant_folding=do_constant_folding, model_kwargs=model_kwargs, + disable_dynamic_axes_fix=disable_dynamic_axes_fix, + custom_export_fn=custom_export_fn, + providers=providers, + session_options=session_options, ) if optimize is not None: @@ -1254,6 +1344,8 @@ def onnx_export_from_model( device=device, use_subprocess=use_subprocess, model_kwargs=model_kwargs, + providers=providers, + session_options=session_options, ) logger.info(f"The ONNX export succeeded and the exported model was saved at: {output.as_posix()}") except ShapeError as e: From 5c8e6a8cdba63aa7d0304dcbf89f8d245d4430c0 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Thu, 21 Nov 2024 10:49:26 +0000 Subject: [PATCH 2/5] fix: typo --- optimum/exporters/onnx/__main__.py | 2 +- optimum/exporters/onnx/convert.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/optimum/exporters/onnx/__main__.py b/optimum/exporters/onnx/__main__.py index f481a495357..92477c7979e 100644 --- a/optimum/exporters/onnx/__main__.py +++ b/optimum/exporters/onnx/__main__.py @@ -187,7 +187,7 @@ def main_export( disable_dynamic_axes_fix (`Optional[bool]`, defaults to `False`): Whether to disable the default dynamic axes fixing. custom_export_fn (`Optional[Callable[..., None]]`, defaults to `None`): - Customized PyTorch ONNX export function. If `None` provided, `torch.onnx.export` is be used. + Customized PyTorch ONNX export function. If `None` provided, `torch.onnx.export` will be used. providers (`Optional[List[str]]`, defaults to `None`): ONNXRuntime execution provides used for the dynamic axis fix and the model validation. If `None` provided, it will be determined by the `device` param. session_options (`Optional["SessionOptions"]`, defaults to `None`): diff --git a/optimum/exporters/onnx/convert.py b/optimum/exporters/onnx/convert.py index 37452a54716..3c5ef3cbd34 100644 --- a/optimum/exporters/onnx/convert.py +++ b/optimum/exporters/onnx/convert.py @@ -564,7 +564,7 @@ def export_pytorch( in case, for example, the model inputs/outputs are changed (for example, if `model_kwargs={"output_attentions": True}` is passed). custom_export_fn (`Optional[Callable[..., None]]`, defaults to `None`): - Customized PyTorch ONNX export function. If `None` provided, `torch.onnx.export` is be used. + Customized PyTorch ONNX export function. If `None` provided, `torch.onnx.export` will be used. Returns: `Tuple[List[str], List[str]]`: A tuple with an ordered list of the model's inputs, and the named outputs from @@ -810,7 +810,7 @@ def export_models( in case, for example, the model inputs/outputs are changed (for example, if `model_kwargs={"output_attentions": True}` is passed). 
custom_export_fn (`Optional[Callable[..., None]]`, defaults to `None`): - Customized PyTorch ONNX export function. If `None` provided, `torch.onnx.export` is be used. + Customized PyTorch ONNX export function. If `None` provided, `torch.onnx.export` will be used. providers (`Optional[List[str]]`, defaults to `None`): ONNXRuntime execution provides used for the dynamic axis fix and the model validation. If `None` provided, it will be determined by the `device` param. session_options (`Optional["SessionOptions"]`, defaults to `None`): @@ -906,7 +906,7 @@ def export( in case, for example, the model inputs/outputs are changed (for example, if `model_kwargs={"output_attentions": True}` is passed). custom_export_fn (`Optional[Callable[..., None]]`, defaults to `None`): - Customized PyTorch ONNX export function. If `None` provided, `torch.onnx.export` is be used. + Customized PyTorch ONNX export function. If `None` provided, `torch.onnx.export` will be used. providers (`Optional[List[str]]`, defaults to `None`): ONNXRuntime execution provides used for the dynamic axis fix and the model validation. If `None` provided, it will be determined by the `device` param. session_options (`Optional["SessionOptions"]`, defaults to `None`): @@ -1074,7 +1074,7 @@ def onnx_export_from_model( disable_dynamic_axes_fix (`Optional[bool]`, defaults to `False`): Whether to disable the default dynamic axes fixing. custom_export_fn (`Optional[Callable[..., None]]`, defaults to `None`): - Customized PyTorch ONNX export function. If `None` provided, `torch.onnx.export` is be used. + Customized PyTorch ONNX export function. If `None` provided, `torch.onnx.export` will be used. providers (`Optional[List[str]]`, defaults to `None`): ONNXRuntime execution provides used for the dynamic axis fix and the model validation. If `None` provided, it will be determined by the `device` param. 
session_options (`Optional["SessionOptions"]`, defaults to `None`): From c10f7637e0164a323070f9c844966d693fd62aaa Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 22 Nov 2024 03:34:51 +0000 Subject: [PATCH 3/5] test: pytorch custom export fn --- tests/exporters/onnx/test_onnx_export.py | 164 ++++++++++++++++++----- 1 file changed, 129 insertions(+), 35 deletions(-) diff --git a/tests/exporters/onnx/test_onnx_export.py b/tests/exporters/onnx/test_onnx_export.py index 88288547c95..ca2c40fdc4c 100644 --- a/tests/exporters/onnx/test_onnx_export.py +++ b/tests/exporters/onnx/test_onnx_export.py @@ -17,8 +17,8 @@ from functools import partial from pathlib import Path from tempfile import TemporaryDirectory -from typing import Dict -from unittest import TestCase +from typing import Callable, Dict, Optional, List +from unittest import TestCase, mock import onnx import pytest @@ -43,7 +43,7 @@ from optimum.exporters.onnx.constants import SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED from optimum.exporters.onnx.model_configs import WhisperOnnxConfig from optimum.exporters.onnx.utils import get_speecht5_models_for_export -from optimum.utils import DummyPastKeyValuesGenerator, NormalizedTextConfig +from optimum.utils import DummyPastKeyValuesGenerator, NormalizedTextConfig, is_onnxruntime_available from optimum.utils.save_utils import maybe_load_preprocessors from optimum.utils.testing_utils import grid_parameters, require_diffusers @@ -60,6 +60,8 @@ if is_torch_available() or is_tf_available(): from optimum.exporters.tasks import TasksManager +if is_onnxruntime_available(): + import onnxruntime as ort SEED = 42 @@ -179,6 +181,10 @@ def _onnx_export( shapes_to_validate: Dict, monolith: bool, device="cpu", + do_validation: bool = True, + custom_export_fn: Optional[Callable[..., None]] = None, + providers: Optional[List[str]] = None, + session_options: Optional["ort.SessionOptions"] = None, ): library_name = TasksManager.infer_library_from_model(model_name) @@ -256,39 +262,44 @@ def _onnx_export( output_dir=Path(tmpdirname), device=device, model_kwargs=model_kwargs, + custom_export_fn=custom_export_fn, + providers=providers, + session_options=session_options, ) - input_shapes_iterator = grid_parameters(shapes_to_validate, yield_dict=True, add_test_name=False) - for input_shapes in input_shapes_iterator: - skip = False - for _, model_onnx_conf in models_and_onnx_configs.items(): - if ( - hasattr(model_onnx_conf[0].config, "max_position_embeddings") - and input_shapes["sequence_length"] >= model_onnx_conf[0].config.max_position_embeddings - ): - skip = True - break - if ( - model_type == "groupvit" - and input_shapes["sequence_length"] - >= model_onnx_conf[0].config.text_config.max_position_embeddings - ): - skip = True - break - if skip: - continue - - try: - validate_models_outputs( - models_and_onnx_configs=models_and_onnx_configs, - onnx_named_outputs=onnx_outputs, - atol=atol, - output_dir=Path(tmpdirname), - input_shapes=input_shapes, - device=device, - model_kwargs=model_kwargs, - ) - except AtolError as e: - print(f"The ONNX export succeeded with the warning: {e}") + + if do_validation: + input_shapes_iterator = grid_parameters(shapes_to_validate, yield_dict=True, add_test_name=False) + for input_shapes in input_shapes_iterator: + skip = False + for _, model_onnx_conf in models_and_onnx_configs.items(): + if ( + hasattr(model_onnx_conf[0].config, "max_position_embeddings") + and input_shapes["sequence_length"] >= model_onnx_conf[0].config.max_position_embeddings + ): + skip = True + break + if ( + 
model_type == "groupvit" + and input_shapes["sequence_length"] + >= model_onnx_conf[0].config.text_config.max_position_embeddings + ): + skip = True + break + if skip: + continue + + try: + validate_models_outputs( + models_and_onnx_configs=models_and_onnx_configs, + onnx_named_outputs=onnx_outputs, + atol=atol, + output_dir=Path(tmpdirname), + input_shapes=input_shapes, + device=device, + model_kwargs=model_kwargs, + ) + except AtolError as e: + print(f"The ONNX export succeeded with the warning: {e}") gc.collect() @@ -347,6 +358,44 @@ def test_pytorch_export_on_cpu( monolith=monolith, ) + @require_torch + @require_vision + @slow + @pytest.mark.run_slow + def test_pytorch_customized_export_on_cpu(self): + from torch.onnx import export as pytorch_export + + providers = ["CPUExecutionProvider"] + + test_name, model_type, model_name, task, onnx_config_class_constructor, monolith = _get_models_to_test( + { + "beit": + "hf-internal-testing/tiny-random-BeitForImageClassification" + })[0] + + if is_onnxruntime_available(): + do_validation = True + so = ort.SessionOptions() + else: + do_validation = False + so = None + + custom_export_mock = mock.MagicMock(side_effect=pytorch_export) + self._onnx_export( + test_name, + model_type, + model_name, + task, + onnx_config_class_constructor, + shapes_to_validate=VALIDATE_EXPORT_ON_SHAPES_SLOW, + monolith=monolith, + custom_export_fn=custom_export_mock, + providers=providers, + session_options=so, + do_validation=do_validation, + ) + custom_export_mock.assert_called_once() + @parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_MODELS_TINY)) @require_torch @require_vision @@ -377,6 +426,51 @@ def test_pytorch_export_on_cuda( monolith=monolith, ) + @require_torch + @require_vision + @require_torch_gpu + @slow + @pytest.mark.run_slow + @pytest.mark.gpu_test + def test_pytorch_customized_export_on_cuda(self): + import torch.version + from torch.onnx import export as pytorch_export + + if torch.version.hip: + providers = ["ROCMExecutionProvider"] + else: + providers = ["CUDAExecutionProvider"] + + test_name, model_type, model_name, task, onnx_config_class_constructor, monolith = _get_models_to_test( + { + "beit": + "hf-internal-testing/tiny-random-BeitForImageClassification" + })[0] + + if is_onnxruntime_available(): + do_validation = True + so = ort.SessionOptions() + else: + do_validation = False + so = None + + custom_export_mock = mock.MagicMock(side_effect=pytorch_export) + self._onnx_export( + test_name, + model_type, + model_name, + task, + onnx_config_class_constructor, + device="cuda", + shapes_to_validate=VALIDATE_EXPORT_ON_SHAPES_SLOW, + monolith=monolith, + custom_export_fn=custom_export_mock, + providers=providers, + session_options=so, + do_validation=do_validation, + ) + custom_export_mock.assert_called_once() + @parameterized.expand(_get_models_to_test(TENSORFLOW_EXPORT_MODELS)) @slow @pytest.mark.run_slow From b4aa01a3addfa91ea885d3b237e0f50635423cd4 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 22 Nov 2024 06:52:11 +0000 Subject: [PATCH 4/5] test: main_export with model object and custom ops --- optimum/exporters/onnx/__main__.py | 12 +- .../exporters/onnx/test_exporters_onnx_cli.py | 151 +++++++++++++++++- 2 files changed, 153 insertions(+), 10 deletions(-) diff --git a/optimum/exporters/onnx/__main__.py b/optimum/exporters/onnx/__main__.py index 92477c7979e..54261d8c50f 100644 --- a/optimum/exporters/onnx/__main__.py +++ b/optimum/exporters/onnx/__main__.py @@ -58,8 +58,8 @@ def main_export( - model_name_or_path_or_obj: Union[str, 
"PreTrainedModel", - "TFPreTrainedModel", "DiffusionPipeline"], + model_name_or_path: Union[str, "PreTrainedModel", "TFPreTrainedModel", + "DiffusionPipeline"], output: Union[str, Path], task: str = "auto", opset: Optional[int] = None, @@ -103,7 +103,7 @@ def main_export( Args: > Required parameters - model_name_or_path_or_obj (`Union[str, "PreTrainedModel", "TFPreTrainedModel", "DiffusionPipeline"]`): + model_name_or_path (`Union[str, "PreTrainedModel", "TFPreTrainedModel", "DiffusionPipeline"]`): Model ID on huggingface.co or path on disk to the model repository to export. Example: `model_name_or_path="BAAI/bge-m3"` or `mode_name_or_path="/path/to/model_folder`. It is also possible to pass a model object to skip getting models from the export task. output (`Union[str, Path]`): @@ -252,11 +252,11 @@ def main_export( "Please use one of the following tasks instead: `text-to-image`, `image-to-image`, `inpainting`." ) - if isinstance(model_name_or_path_or_obj, str): + if isinstance(model_name_or_path, str): model = None - model_name_or_path = model_name_or_path_or_obj + model_name_or_path = model_name_or_path else: - model = model_name_or_path_or_obj + model = model_name_or_path model_name_or_path = model.config._name_or_path if providers is None: diff --git a/tests/exporters/onnx/test_exporters_onnx_cli.py b/tests/exporters/onnx/test_exporters_onnx_cli.py index 9ac7832aa7d..9701f3b3900 100644 --- a/tests/exporters/onnx/test_exporters_onnx_cli.py +++ b/tests/exporters/onnx/test_exporters_onnx_cli.py @@ -15,14 +15,16 @@ import os import subprocess import unittest +from unittest import mock +import itertools from pathlib import Path from tempfile import TemporaryDirectory -from typing import Dict, Optional +from typing import Dict, Optional, Union, TYPE_CHECKING, List, Callable import onnx import pytest from parameterized import parameterized -from transformers import AutoModelForSequenceClassification, AutoTokenizer, is_torch_available +from transformers import AutoModelForSequenceClassification, AutoTokenizer, is_torch_available, AutoModel from transformers.testing_utils import require_torch, require_torch_gpu, require_vision, slow from optimum.exporters.error_utils import MinimumVersionError @@ -33,12 +35,16 @@ ONNX_DECODER_WITH_PAST_NAME, ONNX_ENCODER_NAME, ) +from optimum.utils.import_utils import is_onnxruntime_available from optimum.utils.testing_utils import grid_parameters, require_diffusers, require_sentence_transformers, require_timm if is_torch_available(): from optimum.exporters.tasks import TasksManager +if is_onnxruntime_available(): + import onnxruntime as ort + from ..exporters_utils import ( NO_DYNAMIC_AXES_EXPORT_SHAPES_TRANSFORMERS, PYTORCH_DIFFUSION_MODEL, @@ -49,6 +55,15 @@ PYTORCH_TRANSFORMERS_MODEL_NO_DYNAMIC_AXES, ) +if TYPE_CHECKING: + from optimum.utils.import_utils import is_diffusers_available + from transformers import is_tf_available + from transformers.modeling_utils import PreTrainedModel + if is_diffusers_available(): + from diffusers import DiffusionPipeline + if is_tf_available(): + from transformers.modeling_tf_utils import TFPreTrainedModel + def _get_models_to_test(export_models_dict: Dict, library_name: str): models_to_test = [] @@ -174,7 +189,8 @@ class OnnxCLIExportTestCase(unittest.TestCase): def _onnx_export( self, - model_name: str, + model_name: Union[str, "PreTrainedModel", "TFPreTrainedModel", + "DiffusionPipeline"], task: str, monolith: bool = False, no_post_process: bool = False, @@ -184,6 +200,11 @@ def _onnx_export( variant: str = "default", 
no_dynamic_axes: bool = False, model_kwargs: Optional[Dict] = None, + do_validation: bool = True, + disable_dynamic_axes_fix: bool = False, + custom_export_fn: Optional[Callable[..., None]] = None, + providers: Optional[List[str]] = None, + session_options: Optional["ort.SessionOptions"] = None, ): # We need to set this to some value to be able to test the outputs values for batch size > 1. if task == "text-classification": @@ -206,13 +227,19 @@ def _onnx_export( no_dynamic_axes=no_dynamic_axes, pad_token_id=pad_token_id, model_kwargs=model_kwargs, + do_validation=do_validation, + disable_dynamic_axes_fix=disable_dynamic_axes_fix, + custom_export_fn=custom_export_fn, + providers=providers, + session_options=session_options, ) except MinimumVersionError as e: pytest.skip(f"Skipping due to minimum version requirements not met. Full error: {e}") def _onnx_export_no_dynamic_axes( self, - model_name: str, + model_name: Union[str, "PreTrainedModel", "TFPreTrainedModel", + "DiffusionPipeline"], task: str, input_shape: dict, input_shape_for_validation: tuple, @@ -739,3 +766,119 @@ def test_complex_synonyms(self): model.save_pretrained(tmpdir_in) main_export(model_name_or_path=tmpdir_in, output=tmpdir_out, task="text-classification") + + @parameterized.expand(itertools.product([False, True], ["cuda", "cpu"])) + @require_vision + @require_torch_gpu + @slow + @pytest.mark.run_slow + def test_customized_export( + self, + use_custom_op: bool, + device: str, + ): + import torch.version + from torch import nn + from torch.autograd import Function + from torch.onnx import export as pytorch_export, symbolic_helper + + class CustomActivationFunc(Function): + @staticmethod + def forward(ctx, input_tensor): + return input_tensor + + @staticmethod + def backward(ctx, grad_outputs: torch.Tensor): + return grad_outputs + + @staticmethod + @symbolic_helper.parse_args("v") + def symbolic(g, input_tensor): + ret = g.op('CustomDomain::CustomActivation', input_tensor) + ret.setType(input_tensor.type()) + return ret + + class CustomActivation(nn.Module): + def forward(self, input_tensor): + return CustomActivationFunc.apply(input_tensor) + + def replace_activation(model: nn.Module): + if hasattr(model, "intermediate_act_fn"): + setattr(model, "intermediate_act_fn", CustomActivation()) + for child in model.children(): + replace_activation(child) + + test_name, model_type, model_name, task, variant, monolith, no_post_process = _get_models_to_test( + { + "beit": + "hf-internal-testing/tiny-random-BeitForImageClassification" + }, + library_name="transformers")[0] + + if device.startswith("cuda"): + if torch.version.hip: + providers = ["ROCMExecutionProvider"] + else: + providers = ["CUDAExecutionProvider"] + else: + providers = ["CPUExecutionProvider"] + + if is_onnxruntime_available(): + do_validation = True + so = ort.SessionOptions() + else: + do_validation = False + so = None + + custom_export_mock = mock.MagicMock(side_effect=pytorch_export) + + model = AutoModel.from_pretrained(model_name).to(device) + TasksManager.standardize_model_attributes(model) + + if use_custom_op: + replace_activation(model) + if do_validation: + with pytest.raises(Exception): + # this one will fail because no custom ops are registered in onnxruntime + self._onnx_export( + model, + task, + monolith, + no_post_process, + variant=variant, + device=model.device, + disable_dynamic_axes_fix=not do_validation, + do_validation=do_validation, + custom_export_fn=custom_export_mock, + providers=providers, + session_options=so, + ) + self._onnx_export( 
+ model, + task, + monolith, + no_post_process, + variant=variant, + device=model.device, + disable_dynamic_axes_fix=True, + do_validation=False, + custom_export_fn=custom_export_mock, + providers=providers, + session_options=so, + ) + else: + self._onnx_export( + model, + task, + monolith, + no_post_process, + variant=variant, + device=model.device, + disable_dynamic_axes_fix=not do_validation, + do_validation=do_validation, + custom_export_fn=custom_export_mock, + providers=providers, + session_options=so, + ) + + custom_export_mock.assert_called() From a6251e0d94396674ad7c376058bfb784a5a97a15 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 22 Nov 2024 07:23:28 +0000 Subject: [PATCH 5/5] test: passing `device` in str type --- tests/exporters/onnx/test_exporters_onnx_cli.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/exporters/onnx/test_exporters_onnx_cli.py b/tests/exporters/onnx/test_exporters_onnx_cli.py index 9701f3b3900..df984a0a086 100644 --- a/tests/exporters/onnx/test_exporters_onnx_cli.py +++ b/tests/exporters/onnx/test_exporters_onnx_cli.py @@ -846,7 +846,7 @@ def replace_activation(model: nn.Module): monolith, no_post_process, variant=variant, - device=model.device, + device=device, disable_dynamic_axes_fix=not do_validation, do_validation=do_validation, custom_export_fn=custom_export_mock, @@ -859,7 +859,7 @@ def replace_activation(model: nn.Module): monolith, no_post_process, variant=variant, - device=model.device, + device=device, disable_dynamic_axes_fix=True, do_validation=False, custom_export_fn=custom_export_mock, @@ -873,7 +873,7 @@ def replace_activation(model: nn.Module): monolith, no_post_process, variant=variant, - device=model.device, + device=device, disable_dynamic_axes_fix=not do_validation, do_validation=do_validation, custom_export_fn=custom_export_mock,
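
---

Usage sketch (illustrative only; not part of the patches above): a minimal example of how the parameters introduced in this series could be combined. The checkpoint is the same tiny BEiT model used in the tests; the wrapper name `logging_export` and the explicit `task` value are assumptions for the example. Per the docstrings added in patch 1/5, `custom_export_fn` is expected to accept the same arguments as `torch.onnx.export`, which is what the exporter falls back to when it is `None`, and `providers`/`session_options` are reused for both the dynamic-axes fix and the output validation.

    import torch
    from onnxruntime import SessionOptions
    from transformers import AutoModelForImageClassification

    from optimum.exporters.onnx import main_export

    def logging_export(*args, **kwargs):
        # Hypothetical wrapper: log the opset, then defer to the default
        # PyTorch ONNX exporter that main_export would otherwise call.
        print("exporting with opset", kwargs.get("opset_version"))
        return torch.onnx.export(*args, **kwargs)

    # After patch 4/5, a model object can be passed in place of a model ID,
    # skipping the TasksManager.get_model_from_task() lookup inside main_export.
    model = AutoModelForImageClassification.from_pretrained(
        "hf-internal-testing/tiny-random-BeitForImageClassification"
    )

    session_options = SessionOptions()  # shared by the axes fix and validation

    main_export(
        model,
        output="exported_model",
        task="image-classification",
        custom_export_fn=logging_export,
        providers=["CPUExecutionProvider"],  # overrides the device-derived default
        session_options=session_options,
    )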