Expose more params in main_export for customized ONNX export #2100

Open · wants to merge 5 commits into main
89 changes: 69 additions & 20 deletions optimum/exporters/onnx/__main__.py
@@ -22,7 +22,7 @@
from packaging import version
from requests.exceptions import ConnectionError as RequestsConnectionError
from transformers import AutoConfig, AutoTokenizer
from transformers.utils import is_torch_available
from transformers.utils import is_torch_available, is_tf_available

from ...commands.export.onnx import parse_args_onnx
from ...configuration_utils import _transformers_version
@@ -36,17 +36,30 @@
if is_torch_available():
    import torch

from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union
if is_tf_available():
    import tensorflow as tf

from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union, List


if TYPE_CHECKING:
    from .base import OnnxConfig

from ...utils.import_utils import is_onnxruntime_available, is_diffusers_available

if is_onnxruntime_available():
    from onnxruntime import SessionOptions
if is_torch_available():
    from transformers.modeling_utils import PreTrainedModel
if is_diffusers_available():
    from diffusers import DiffusionPipeline
if is_tf_available():
    from transformers.modeling_tf_utils import TFPreTrainedModel

logger = logging.get_logger()


def main_export(
    model_name_or_path: str,
    model_name_or_path: Union[str, "PreTrainedModel", "TFPreTrainedModel", "DiffusionPipeline"],
    output: Union[str, Path],
    task: str = "auto",
    opset: Optional[int] = None,
@@ -78,6 +91,10 @@ def main_export(
    legacy: bool = False,
    no_dynamic_axes: bool = False,
    do_constant_folding: bool = True,
    disable_dynamic_axes_fix: bool = False,
    custom_export_fn: Optional[Callable[..., None]] = None,
    providers: Optional[List[str]] = None,
    session_options: Optional["SessionOptions"] = None,
    **kwargs_shapes,
):
    """
@@ -86,8 +103,9 @@
    Args:
        > Required parameters

        model_name_or_path (`str`):
        model_name_or_path (`Union[str, "PreTrainedModel", "TFPreTrainedModel", "DiffusionPipeline"]`):
            Model ID on huggingface.co or path on disk to the model repository to export. Example: `model_name_or_path="BAAI/bge-m3"` or `model_name_or_path="/path/to/model_folder"`.
            It is also possible to pass a model instance directly, in which case loading the model from the task is skipped.
        output (`Union[str, Path]`):
            Path indicating the directory where to store the generated ONNX model.
@@ -166,6 +184,14 @@ def main_export(
            If True, disables the use of dynamic axes during ONNX export.
        do_constant_folding (bool, defaults to `True`):
            PyTorch-specific argument. If `True`, the PyTorch ONNX export will fold constants into adjacent nodes, if possible.
        disable_dynamic_axes_fix (`Optional[bool]`, defaults to `False`):
            Whether to disable the default fixing of dynamic axes.
        custom_export_fn (`Optional[Callable[..., None]]`, defaults to `None`):
            Custom PyTorch ONNX export function. If `None`, `torch.onnx.export` is used.
        providers (`Optional[List[str]]`, defaults to `None`):
            ONNX Runtime execution providers used for the dynamic axes fix and the model validation. If `None`, they are inferred from the `device` param.
        session_options (`Optional["SessionOptions"]`, defaults to `None`):
            ONNX Runtime session options used for the dynamic axes fix and the model validation. If `None`, a default `SessionOptions` object is created.
        **kwargs_shapes (`Dict`):
            Shapes to use during inference. This argument allows overriding the default shapes used during the ONNX export.

@@ -226,6 +252,24 @@ def main_export(
            "Please use one of the following tasks instead: `text-to-image`, `image-to-image`, `inpainting`."
        )

    if isinstance(model_name_or_path, str):
        model = None
    else:
        model = model_name_or_path
        model_name_or_path = model.config._name_or_path

    if providers is None:
        if device.startswith("cuda"):
            if (is_torch_available() and torch.version.hip) or (
                is_tf_available() and tf.test.is_built_with_rocm()
            ):
                providers = ["ROCMExecutionProvider"]
            else:
                providers = ["CUDAExecutionProvider"]
        else:
            providers = ["CPUExecutionProvider"]

    original_task = task
    task = TasksManager.map_from_synonym(task)

@@ -300,22 +344,23 @@
    if model_type in SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED and _transformers_version >= version.parse("4.35.99"):
        loading_kwargs["attn_implementation"] = "eager"

    model = TasksManager.get_model_from_task(
        task,
        model_name_or_path,
        subfolder=subfolder,
        revision=revision,
        cache_dir=cache_dir,
        token=token,
        local_files_only=local_files_only,
        force_download=force_download,
        trust_remote_code=trust_remote_code,
        framework=framework,
        torch_dtype=torch_dtype,
        device=device,
        library_name=library_name,
        **loading_kwargs,
    )
    if model is None:
        model = TasksManager.get_model_from_task(
            task,
            model_name_or_path,
            subfolder=subfolder,
            revision=revision,
            cache_dir=cache_dir,
            token=token,
            local_files_only=local_files_only,
            force_download=force_download,
            trust_remote_code=trust_remote_code,
            framework=framework,
            torch_dtype=torch_dtype,
            device=device,
            library_name=library_name,
            **loading_kwargs,
        )

    needs_pad_token_id = task == "text-classification" and getattr(model.config, "pad_token_id", None) is None

@@ -390,6 +435,10 @@
        task=task,
        use_subprocess=use_subprocess,
        do_constant_folding=do_constant_folding,
        disable_dynamic_axes_fix=disable_dynamic_axes_fix,
        custom_export_fn=custom_export_fn,
        providers=providers,
        session_options=session_options,
        **kwargs_shapes,
    )

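Taken together, the new parameters let a caller export an already-instantiated model with a custom export function and explicit ONNX Runtime settings, instead of having everything derived from a model ID and a device string. Below is a minimal usage sketch, not part of the diff: the checkpoint name and the my_export_fn wrapper are illustrative assumptions, and it presumes optimum invokes custom_export_fn with the same arguments it would otherwise pass to torch.onnx.export.

from pathlib import Path

import torch
from onnxruntime import SessionOptions
from transformers import AutoModelForSequenceClassification

from optimum.exporters.onnx import main_export

# Load the model up front instead of letting main_export resolve a model ID.
model = AutoModelForSequenceClassification.from_pretrained(
    "distilbert-base-uncased-finetuned-sst-2-english"
)

def my_export_fn(model, args, f, **kwargs):
    # Hypothetical drop-in replacement for torch.onnx.export that adds logging.
    print(f"exporting to {f} with opset {kwargs.get('opset_version')}")
    torch.onnx.export(model, args, f, **kwargs)

session_options = SessionOptions()
session_options.intra_op_num_threads = 1  # e.g. deterministic single-threaded validation

main_export(
    model,                               # a model instance is now accepted in place of a model ID
    output=Path("onnx_out"),
    task="text-classification",
    custom_export_fn=my_export_fn,       # used instead of torch.onnx.export
    providers=["CPUExecutionProvider"],  # bypasses the device-based auto-detection above
    session_options=session_options,     # reused for the dynamic axes fix and validation
)

Passing providers explicitly sidesteps the device-based heuristic, which matters when the heuristic's choice (CUDAExecutionProvider for any "cuda" device on a non-ROCm build) is not the provider you actually want.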
28 changes: 19 additions & 9 deletions optimum/exporters/onnx/base.py
@@ -63,6 +63,9 @@
if is_diffusers_available():
    from diffusers import ModelMixin

if is_onnxruntime_available():
    from onnxruntime import SessionOptions

from .model_patcher import PatchingSpec
logger = logging.get_logger(__name__)
@@ -269,9 +272,12 @@ def variant(self, value: str):
            raise ValueError(f"The variant {value} is not supported for the ONNX config {self.__class__.__name__}.")
        self._variant = value

    def fix_dynamic_axes(
        self, model_path: "Path", device: str = "cpu", dtype: Optional[str] = None, input_shapes: Optional[Dict] = None
    ):
    def fix_dynamic_axes(
        self,
        model_path: "Path",
        providers: List[str],
        session_options: Optional["SessionOptions"] = None,
        dtype: Optional[str] = None,
        input_shapes: Optional[Dict] = None,
    ):
"""
Fixes potential issues with dynamic axes.
During the export, ONNX will infer some axes to be dynamic which are actually static. This method is called
@@ -280,6 +286,10 @@ def fix_dynamic_axes(
        Args:
            model_path (`Path`):
                The path of the freshly exported ONNX model.
            providers (`List[str]`):
                ONNX Runtime execution providers used for the dynamic axes fix and the model validation.
            session_options (`Optional["SessionOptions"]`, defaults to `None`):
                ONNX Runtime session options used for the dynamic axes fix and the model validation. If `None`, a default `SessionOptions` object is created.
        """
        if not (is_onnx_available() and is_onnxruntime_available()):
            raise RuntimeError(
@@ -296,12 +306,10 @@ def fix_dynamic_axes(
        for output in self.outputs.values():
            allowed_dynamic_axes |= set(output.values())

if device.startswith("cuda"):
providers = ["CUDAExecutionProvider"]
else:
providers = ["CPUExecutionProvider"]

session_options = SessionOptions()
if session_options is None:
session_options = SessionOptions()
# backup the original optimization level
original_opt_level = session_options.graph_optimization_level
session_options.graph_optimization_level = GraphOptimizationLevel.ORT_DISABLE_ALL # no need to optimize here
session = InferenceSession(model_path.as_posix(), providers=providers, sess_options=session_options)

@@ -349,6 +357,8 @@ def fix_dynamic_axes(
            )
        del onnx_model
        gc.collect()
        # restore the original optimization level
        session_options.graph_optimization_level = original_opt_level

    def patch_model_for_export(
        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
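Since fix_dynamic_axes no longer builds its provider list from a device string, call sites inside optimum (and any external callers) now pass providers explicitly. A rough sketch of the new calling convention; onnx_config is assumed to be an OnnxConfig instance for the exported architecture and model.onnx a freshly exported file, both placeholders:

from pathlib import Path

from onnxruntime import SessionOptions

session_options = SessionOptions()
session_options.log_severity_level = 3  # e.g. keep ONNX Runtime quiet during the fix

# The provider list is now the caller's decision, so ROCm (or any other
# execution provider) can be used where the old device check only knew CUDA/CPU.
onnx_config.fix_dynamic_axes(
    Path("model.onnx"),
    providers=["ROCMExecutionProvider"],
    session_options=session_options,
)

Because the passed-in SessionOptions object may be reused by the caller, the method now also restores graph_optimization_level after temporarily disabling optimizations, as the last hunk above shows.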