From 7680b3e3b3775f97ff98fa5d6d2d6976e384a2c7 Mon Sep 17 00:00:00 2001 From: eaidova Date: Wed, 4 Dec 2024 09:25:29 +0400 Subject: [PATCH 1/8] move check_dummy_inputs_allowed to common export utils --- optimum/exporters/onnx/convert.py | 27 ++------------------------- optimum/exporters/utils.py | 27 ++++++++++++++++++++++++++- 2 files changed, 28 insertions(+), 26 deletions(-) diff --git a/optimum/exporters/onnx/convert.py b/optimum/exporters/onnx/convert.py index c12a9ac222a..0d4c544cd3a 100644 --- a/optimum/exporters/onnx/convert.py +++ b/optimum/exporters/onnx/convert.py @@ -22,7 +22,7 @@ from inspect import signature from itertools import chain from pathlib import Path -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import onnx @@ -45,6 +45,7 @@ from ...utils.save_utils import maybe_save_preprocessors from ..error_utils import AtolError, MinimumVersionError, OutputMatchError, ShapeError from ..tasks import TasksManager +from ..utils import check_dummy_inputs_are_allowed from .base import OnnxConfig from .constants import UNPICKABLE_ARCHS from .model_configs import SpeechT5OnnxConfig @@ -75,30 +76,6 @@ class DynamicAxisNameError(ValueError): pass -def check_dummy_inputs_are_allowed( - model: Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"], dummy_input_names: Iterable[str] -): - """ - Checks that the dummy inputs from the ONNX config is a subset of the allowed inputs for `model`. - Args: - model (`Union[transformers.PreTrainedModel, transformers.TFPreTrainedModel`]): - The model instance. - model_inputs (`Iterable[str]`): - The model input names. - """ - - forward = model.forward if is_torch_available() and isinstance(model, nn.Module) else model.call - forward_parameters = signature(forward).parameters - forward_inputs_set = set(forward_parameters.keys()) - dummy_input_names = set(dummy_input_names) - - # We are fine if config_inputs has more keys than model_inputs - if not dummy_input_names.issubset(forward_inputs_set): - raise ValueError( - f"Config dummy inputs are not a subset of the model inputs: {dummy_input_names} vs {forward_inputs_set}" - ) - - def validate_models_outputs( models_and_onnx_configs: Dict[ str, Tuple[Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"], "OnnxConfig"] diff --git a/optimum/exporters/utils.py b/optimum/exporters/utils.py index 60de169de5e..59e053ee444 100644 --- a/optimum/exporters/utils.py +++ b/optimum/exporters/utils.py @@ -16,7 +16,8 @@ """Utilities for model preparation to export.""" import copy -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union +from inspect import signature +from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Tuple, Union import torch from packaging import version @@ -675,3 +676,27 @@ def _get_submodels_and_export_configs( export_config = next(iter(models_and_export_configs.values()))[1] return export_config, models_and_export_configs + + +def check_dummy_inputs_are_allowed( + model: Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"], dummy_input_names: Iterable[str] +): + """ + Checks that the dummy inputs from the ONNX config is a subset of the allowed inputs for `model`. + Args: + model (`Union[transformers.PreTrainedModel, transformers.TFPreTrainedModel`]): + The model instance. + model_inputs (`Iterable[str]`): + The model input names. + """ + + forward = model.forward if is_torch_available() and isinstance(model, nn.Module) else model.call + forward_parameters = signature(forward).parameters + forward_inputs_set = set(forward_parameters.keys()) + dummy_input_names = set(dummy_input_names) + + # We are fine if config_inputs has more keys than model_inputs + if not dummy_input_names.issubset(forward_inputs_set): + raise ValueError( + f"Config dummy inputs are not a subset of the model inputs: {dummy_input_names} vs {forward_inputs_set}" + ) \ No newline at end of file From d5ceb674582a1c2135e7cd5406cff2e4200febc6 Mon Sep 17 00:00:00 2001 From: eaidova Date: Wed, 4 Dec 2024 10:52:47 +0400 Subject: [PATCH 2/8] move decoder_merge import --- optimum/exporters/onnx/model_configs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index b39d19ec782..12cf9c6108b 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -21,7 +21,6 @@ from packaging import version from transformers.utils import is_tf_available -from ...onnx import merge_decoders from ...utils import ( DEFAULT_DUMMY_SHAPES, BloomDummyPastKeyValuesGenerator, @@ -1875,6 +1874,7 @@ def post_process_exported_models( decoder_with_past_path = Path(path, onnx_files_subpaths[3]) decoder_merged_path = Path(path, ONNX_DECODER_MERGED_NAME + ".onnx") try: + from ...onnx import merge_decoders # The decoder with past does not output the cross attention past key values as they are constant, # hence the need for strict=False merge_decoders( From 9b3f29c73c3b96deeac93bf93f555ce99737da96 Mon Sep 17 00:00:00 2001 From: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com> Date: Wed, 4 Dec 2024 10:22:49 +0100 Subject: [PATCH 3/8] Update optimum/exporters/utils.py --- optimum/exporters/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optimum/exporters/utils.py b/optimum/exporters/utils.py index 59e053ee444..d386c09b7d8 100644 --- a/optimum/exporters/utils.py +++ b/optimum/exporters/utils.py @@ -690,7 +690,7 @@ def check_dummy_inputs_are_allowed( The model input names. """ - forward = model.forward if is_torch_available() and isinstance(model, nn.Module) else model.call + forward = model.forward if is_torch_available() and isinstance(model, torch.nn.Module) else model.call forward_parameters = signature(forward).parameters forward_inputs_set = set(forward_parameters.keys()) dummy_input_names = set(dummy_input_names) From f9e4615b273cdb10c5735dbc21a54d2fb2b03d75 Mon Sep 17 00:00:00 2001 From: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com> Date: Wed, 4 Dec 2024 10:22:49 +0100 Subject: [PATCH 4/8] Update optimum/exporters/utils.py --- optimum/exporters/onnx/model_configs.py | 1 + optimum/exporters/utils.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index 12cf9c6108b..fdb69317277 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -1875,6 +1875,7 @@ def post_process_exported_models( decoder_merged_path = Path(path, ONNX_DECODER_MERGED_NAME + ".onnx") try: from ...onnx import merge_decoders + # The decoder with past does not output the cross attention past key values as they are constant, # hence the need for strict=False merge_decoders( diff --git a/optimum/exporters/utils.py b/optimum/exporters/utils.py index 59e053ee444..d4a4111075d 100644 --- a/optimum/exporters/utils.py +++ b/optimum/exporters/utils.py @@ -690,7 +690,7 @@ def check_dummy_inputs_are_allowed( The model input names. """ - forward = model.forward if is_torch_available() and isinstance(model, nn.Module) else model.call + forward = model.forward if is_torch_available() and isinstance(model, torch.nn.Module) else model.call forward_parameters = signature(forward).parameters forward_inputs_set = set(forward_parameters.keys()) dummy_input_names = set(dummy_input_names) @@ -699,4 +699,4 @@ def check_dummy_inputs_are_allowed( if not dummy_input_names.issubset(forward_inputs_set): raise ValueError( f"Config dummy inputs are not a subset of the model inputs: {dummy_input_names} vs {forward_inputs_set}" - ) \ No newline at end of file + ) From deb4ec35e1660c609cdceb7abc7bc1ee211135fd Mon Sep 17 00:00:00 2001 From: eaidova Date: Wed, 4 Dec 2024 17:21:09 +0400 Subject: [PATCH 5/8] avoid onnx import if not necessary --- optimum/exporters/onnx/base.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/optimum/exporters/onnx/base.py b/optimum/exporters/onnx/base.py index 7e35691d54b..f1b6ad1f376 100644 --- a/optimum/exporters/onnx/base.py +++ b/optimum/exporters/onnx/base.py @@ -27,16 +27,12 @@ from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union import numpy as np -import onnx from transformers.utils import is_accelerate_available, is_torch_available -from ...onnx import remove_duplicate_weights_from_tied_info - if is_torch_available(): import torch.nn as nn -from ...onnx import merge_decoders from ...utils import ( DEFAULT_DUMMY_SHAPES, DummyInputGenerator, @@ -315,6 +311,8 @@ def fix_dynamic_axes( # We branch here to avoid doing an unnecessary forward pass. if to_fix: + import onnx + if input_shapes is None: input_shapes = {} dummy_inputs = self.generate_dummy_inputs(framework="np", **input_shapes) @@ -542,6 +540,10 @@ def post_process_exported_models( first_key = next(iter(models_and_onnx_configs)) if is_torch_available() and isinstance(models_and_onnx_configs[first_key][0], nn.Module): if is_accelerate_available(): + import onnx + + from ...onnx import remove_duplicate_weights_from_tied_info + logger.info("Deduplicating shared (tied) weights...") for subpath, key in zip(onnx_files_subpaths, models_and_onnx_configs): torch_model = models_and_onnx_configs[key][0] @@ -934,6 +936,8 @@ def post_process_exported_models( decoder_with_past_path = Path(path, onnx_files_subpaths[2]) decoder_merged_path = Path(path, ONNX_DECODER_MERGED_NAME + ".onnx") try: + from ...onnx import merge_decoders + # The decoder with past does not output the cross attention past key values as they are constant, # hence the need for strict=False merge_decoders( From be3c79cd1a418f61b1720cd121ffa4a1af6e0a58 Mon Sep 17 00:00:00 2001 From: eaidova Date: Wed, 18 Dec 2024 16:46:55 +0400 Subject: [PATCH 6/8] move merge decoders import --- optimum/exporters/onnx/config.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/optimum/exporters/onnx/config.py b/optimum/exporters/onnx/config.py index 9e808e392b9..43709ed32f8 100644 --- a/optimum/exporters/onnx/config.py +++ b/optimum/exporters/onnx/config.py @@ -19,8 +19,6 @@ from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union from transformers.utils import is_tf_available - -from ...onnx import merge_decoders from ...utils import ( DummyAudioInputGenerator, DummyBboxInputGenerator, @@ -129,6 +127,7 @@ def post_process_exported_models( # Attempt to merge only if the decoder-only was exported separately without/with past if self.use_past is True and len(models_and_onnx_configs) == 2: + from ...onnx import merge_decoders decoder_path = Path(path, onnx_files_subpaths[0]) decoder_with_past_path = Path(path, onnx_files_subpaths[1]) decoder_merged_path = Path(path, ONNX_DECODER_MERGED_NAME + ".onnx") From f3962cad1b0ab1d6286b1b57fd7ee858eca7bc5d Mon Sep 17 00:00:00 2001 From: eaidova Date: Thu, 19 Dec 2024 08:47:00 +0400 Subject: [PATCH 7/8] fix style --- optimum/exporters/onnx/base.py | 2 -- optimum/exporters/onnx/config.py | 2 ++ 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/optimum/exporters/onnx/base.py b/optimum/exporters/onnx/base.py index f1b6ad1f376..137e024ce77 100644 --- a/optimum/exporters/onnx/base.py +++ b/optimum/exporters/onnx/base.py @@ -311,8 +311,6 @@ def fix_dynamic_axes( # We branch here to avoid doing an unnecessary forward pass. if to_fix: - import onnx - if input_shapes is None: input_shapes = {} dummy_inputs = self.generate_dummy_inputs(framework="np", **input_shapes) diff --git a/optimum/exporters/onnx/config.py b/optimum/exporters/onnx/config.py index 43709ed32f8..d4e0630171b 100644 --- a/optimum/exporters/onnx/config.py +++ b/optimum/exporters/onnx/config.py @@ -19,6 +19,7 @@ from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union from transformers.utils import is_tf_available + from ...utils import ( DummyAudioInputGenerator, DummyBboxInputGenerator, @@ -128,6 +129,7 @@ def post_process_exported_models( # Attempt to merge only if the decoder-only was exported separately without/with past if self.use_past is True and len(models_and_onnx_configs) == 2: from ...onnx import merge_decoders + decoder_path = Path(path, onnx_files_subpaths[0]) decoder_with_past_path = Path(path, onnx_files_subpaths[1]) decoder_merged_path = Path(path, ONNX_DECODER_MERGED_NAME + ".onnx") From 055297edb53c3b81b7d476d4b83f2e01f9691edc Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Thu, 19 Dec 2024 11:13:15 +0100 Subject: [PATCH 8/8] add comment --- optimum/exporters/onnx/base.py | 2 ++ optimum/exporters/onnx/config.py | 3 +++ optimum/exporters/onnx/convert.py | 2 ++ optimum/exporters/onnx/model_configs.py | 3 +++ 4 files changed, 10 insertions(+) diff --git a/optimum/exporters/onnx/base.py b/optimum/exporters/onnx/base.py index 137e024ce77..b5adb4522a2 100644 --- a/optimum/exporters/onnx/base.py +++ b/optimum/exporters/onnx/base.py @@ -50,6 +50,8 @@ from .model_patcher import ModelPatcher, Seq2SeqModelPatcher +# TODO : moved back onnx imports applied in https://github.com/huggingface/optimum/pull/2114/files after refactorization + if is_accelerate_available(): from accelerate.utils import find_tied_parameters diff --git a/optimum/exporters/onnx/config.py b/optimum/exporters/onnx/config.py index d4e0630171b..69366d6be13 100644 --- a/optimum/exporters/onnx/config.py +++ b/optimum/exporters/onnx/config.py @@ -37,6 +37,9 @@ from .model_patcher import DecoderModelPatcher +# TODO : moved back onnx imports applied in https://github.com/huggingface/optimum/pull/2114/files after refactorization + + if TYPE_CHECKING: from transformers import PretrainedConfig, PreTrainedModel diff --git a/optimum/exporters/onnx/convert.py b/optimum/exporters/onnx/convert.py index 0d4c544cd3a..80d945580c7 100644 --- a/optimum/exporters/onnx/convert.py +++ b/optimum/exporters/onnx/convert.py @@ -57,6 +57,8 @@ ) +# TODO : moved back onnx imports applied in https://github.com/huggingface/optimum/pull/2114/files after refactorization + if is_torch_available(): import torch import torch.nn as nn diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index fdb69317277..315fced395a 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -92,6 +92,9 @@ ) +# TODO : moved back onnx imports applied in https://github.com/huggingface/optimum/pull/2114/files after refactorization + + if TYPE_CHECKING: from transformers import PretrainedConfig from transformers.modeling_utils import PreTrainedModel