From 54d3bc7aefd55270ad468e93c71433b93793260c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?F=C3=A9lix=20Marty?= <9808326+fxmarty@users.noreply.github.com>
Date: Thu, 21 Sep 2023 16:18:15 +0200
Subject: [PATCH] working export

---
 optimum/commands/export/onnx.py         |  7 +++
 optimum/exporters/onnx/__main__.py      |  2 +-
 optimum/exporters/onnx/base.py          |  8 +++-
 optimum/exporters/onnx/convert.py       | 39 ++++++++---------
 optimum/exporters/onnx/model_configs.py | 57 ++++++++++++++++++++++---
 optimum/exporters/onnx/model_patcher.py | 22 +++++++---
 optimum/exporters/onnx/utils.py         | 19 +++++++--
 optimum/utils/input_generators.py       |  1 +
 8 files changed, 116 insertions(+), 39 deletions(-)

diff --git a/optimum/commands/export/onnx.py b/optimum/commands/export/onnx.py
index d496f6f0392..a9ccae15375 100644
--- a/optimum/commands/export/onnx.py
+++ b/optimum/commands/export/onnx.py
@@ -14,6 +14,7 @@
 """Defines the command line for the export with ONNX."""

 import argparse
+import json
 from pathlib import Path
 from typing import TYPE_CHECKING

@@ -143,6 +144,11 @@ def parse_args_onnx(parser):
             "Disable the use of position_ids for text-generation models that require it for batched generation. This argument is introduced for backward compatibility and will be removed in a future release of Optimum."
         ),
     )
+    optional_group.add_argument(
+        "--model-kwargs",
+        type=json.loads,
+        help=("Any kwargs passed to the model forward, or used to customize the export for a given model. Must be provided as a JSON string."),
+    )

     input_group = parser.add_argument_group(
         "Input shapes (if necessary, this allows to override the shapes of the input given to the ONNX exporter, that requires an example input)."
@@ -256,5 +262,6 @@ def run(self):
             _variant=self.args.variant,
             library_name=self.args.library_name,
             no_position_ids=self.args.no_position_ids,
+            model_kwargs=self.args.model_kwargs,
             **input_shapes,
         )
diff --git a/optimum/exporters/onnx/__main__.py b/optimum/exporters/onnx/__main__.py
index 8b93d487dde..34c85173f80 100644
--- a/optimum/exporters/onnx/__main__.py
+++ b/optimum/exporters/onnx/__main__.py
@@ -97,7 +97,7 @@ def _get_submodels_and_onnx_configs(
             onnx_config.variant = _variant
             all_variants = "\n".join(
-                [f"\t- {name}: {description}" for name, description in onnx_config.VARIANTS.items()]
+                [f"    - {name}: {description}" for name, description in onnx_config.VARIANTS.items()]
             )
             logger.info(f"Using the export variant {onnx_config.variant}. Available variants are:\n{all_variants}")

diff --git a/optimum/exporters/onnx/base.py b/optimum/exporters/onnx/base.py
index 1e2ae99955c..ff645b3be2f 100644
--- a/optimum/exporters/onnx/base.py
+++ b/optimum/exporters/onnx/base.py
@@ -200,7 +200,8 @@ def __init__(
         int_dtype: str = "int64",
         float_dtype: str = "fp32",
     ):
-        if task not in self._TASK_TO_COMMON_OUTPUTS:
+        # TODO: is this check still needed? "text-to-speech" has to be exempted because it is not in _TASK_TO_COMMON_OUTPUTS.
+        if task not in self._TASK_TO_COMMON_OUTPUTS and task != "text-to-speech":
             raise ValueError(
                 f"{task} is not a supported task, supported tasks: {', '.join(self._TASK_TO_COMMON_OUTPUTS.keys())}"
             )
@@ -808,7 +809,8 @@ def with_behavior(
         """
         if isinstance(behavior, str) and not isinstance(behavior, ConfigBehavior):
             behavior = ConfigBehavior(behavior)
-        return self.__class__(
+
+        onnx_config = self.__class__(
             self._config,
             task=self.task,
             int_dtype=self.int_dtype,
@@ -818,6 +820,8 @@ def with_behavior(
             behavior=behavior,
             preprocessors=self._preprocessors,
         )
+        onnx_config.variant = self.variant
+        return onnx_config

     @property
     def outputs(self) -> Dict[str, Dict[int, str]]:
diff --git a/optimum/exporters/onnx/convert.py b/optimum/exporters/onnx/convert.py
index f637da07804..821a39cb06b 100644
--- a/optimum/exporters/onnx/convert.py
+++ b/optimum/exporters/onnx/convert.py
@@ -560,25 +560,26 @@ def remap(value):
     if is_torch_less_than_1_11:
         raise RuntimeError("The ONNX export using the PyTorch framework is only supported for v1.11+")
     else:
-        with config.patch_model_for_export(model, model_kwargs=model_kwargs):
-            check_dummy_inputs_are_allowed(model, dummy_inputs)
-
-            inputs = config.ordered_inputs(model)
-            input_names = list(inputs.keys())
-            output_names = list(config.outputs.keys())
-
-            # Export can work with named args but the dict containing named args has to be the last element of the args
-            # tuple.
-            onnx_export(
-                model,
-                (dummy_inputs,),
-                f=output.as_posix(),
-                input_names=input_names,
-                output_names=output_names,
-                dynamic_axes=dict(chain(inputs.items(), config.outputs.items())),
-                do_constant_folding=True,
-                opset_version=opset,
-            )
+        with torch.no_grad():
+            with config.patch_model_for_export(model, model_kwargs=model_kwargs):
+                check_dummy_inputs_are_allowed(model, dummy_inputs)
+
+                inputs = config.ordered_inputs(model)
+                input_names = list(inputs.keys())
+                output_names = list(config.outputs.keys())
+
+                # Export can work with named args but the dict containing named args has to be the last element of the args
+                # tuple.
+                onnx_export(
+                    model,
+                    (dummy_inputs,),
+                    f=output.as_posix(),
+                    input_names=input_names,
+                    output_names=output_names,
+                    dynamic_axes=dict(chain(inputs.items(), config.outputs.items())),
+                    do_constant_folding=True,
+                    opset_version=opset,
+                )

     # check if external data was exported
     # TODO: this is quite inefficient as we load in memory if models are <2GB without external data
diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py
index d339f565207..bbffcec529f 100644
--- a/optimum/exporters/onnx/model_configs.py
+++ b/optimum/exporters/onnx/model_configs.py
@@ -24,6 +24,7 @@
     BloomDummyPastKeyValuesGenerator,
     DummyAudioInputGenerator,
     DummyDecoderTextInputGenerator,
+    DummyInputGenerator,
     DummyPastKeyValuesGenerator,
     DummyPix2StructInputGenerator,
     DummyPointsGenerator,
@@ -56,7 +57,7 @@
     VisionOnnxConfig,
 )
 from .model_patcher import SAMModelPatcher, SpeechT5ModelPatcher, WavLMModelPatcher
-from ...utils import DummyInputGenerator
+

 if TYPE_CHECKING:
     from transformers import PretrainedConfig
@@ -1158,7 +1159,7 @@ def __init__(
         self.sequence_length = sequence_length
         self.speaker_embedding_dim = normalized_config.speaker_embedding_dim
-        self.num_mel_bins = normalized_config.speaker_embedding_dim
+        self.num_mel_bins = normalized_config.num_mel_bins

     def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
         if input_name == "output_sequence":
@@ -1182,11 +1183,12 @@
 class SpeechT5OnnxConfig(OnnxSeq2SeqConfigWithPast):
     # TODO: Transformers batched generation for Speecht5 is BROKEN (https://github.com/huggingface/transformers/pull/25943),
     # so we won't support for now.
-    NORMALIZED_CONFIG_CLASS = None
+    NORMALIZED_CONFIG_CLASS = NormalizedConfig
     DUMMY_INPUT_GENERATOR_CLASSES = (
         DummyTextInputGenerator,
         DummySeq2SeqDecoderTextInputGenerator,
         T5DummySeq2SeqPastKeyValuesGenerator,
+        DummySpeechT5InputGenerator,
     )
     DUMMY_PKV_GENERATOR_CLASS = T5DummySeq2SeqPastKeyValuesGenerator

@@ -1198,6 +1200,30 @@ class SpeechT5OnnxConfig(OnnxSeq2SeqConfigWithPast):
     }
     DEFAULT_VARIANT = "with-past"

+    def __init__(
+        self,
+        config: "PretrainedConfig",
+        task: str = "feature-extraction",
+        int_dtype: str = "int64",
+        float_dtype: str = "fp32",
+        use_past: bool = False,
+        use_past_in_inputs: bool = False,
+        behavior: ConfigBehavior = ConfigBehavior.MONOLITH,
+        preprocessors: Optional[List[Any]] = None,
+        is_postnet_and_vocoder: bool = False,
+    ):
+        super().__init__(
+            config=config,
+            task=task,
+            int_dtype=int_dtype,
+            float_dtype=float_dtype,
+            use_past=use_past,
+            use_past_in_inputs=use_past_in_inputs,
+            behavior=behavior,
+            preprocessors=preprocessors,
+        )
+        self.is_postnet_and_vocoder = is_postnet_and_vocoder
+
     @property
     def inputs(self) -> Dict[str, Dict[int, str]]:
         common_inputs = {}
@@ -1210,7 +1236,7 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
             # https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/speecht5/modeling_speecht5.py#L2573
             common_inputs["output_sequence"] = {1: "decoder_sequence_length"}
             common_inputs["speaker_embeddings"] = {}  # No dynamic shape here.
-            common_inputs["encoder_hidden_states"] = {1: "encoder_sequence_length"}
+            common_inputs["encoder_outputs"] = {1: "encoder_sequence_length"}
             common_inputs["encoder_attention_mask"] = {1: "encoder_sequence_length"}

             if self.variant == "with-past" and self.use_past_in_inputs:
@@ -1229,12 +1255,12 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
     def outputs(self) -> Dict[str, Dict[int, str]]:
         common_outputs = {}
         if self._behavior is ConfigBehavior.ENCODER:
-            common_outputs["encoder_hidden_states"] = {1: "encoder_sequence_length"}
+            common_outputs["encoder_outputs"] = {1: "encoder_sequence_length"}
             common_outputs["encoder_attention_mask"] = {1: "encoder_sequence_length"}
         elif self._behavior is ConfigBehavior.DECODER:
-            common_outputs["output_sequence"] = {1: "decoder_sequence_length + 1"}
-            common_outputs["prob"] = {}  # No dynamic shape here.
+            common_outputs["output_sequence_out"] = {1: "decoder_sequence_length + 1"}
             common_outputs["spectrum"] = {}  # No dynamic shape here.
+            common_outputs["prob"] = {}  # No dynamic shape here.

             if self.variant == "with-past" and self.use_past:
                 # When exporting decoder models with use_cache=True, both the decoder without past and with past have the KV cache as an output.
@@ -1253,6 +1279,23 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
     ) -> "ModelPatcher":
         return SpeechT5ModelPatcher(self, model, model_kwargs=model_kwargs)

+    @property
+    def torch_to_onnx_input_map(self) -> Dict[str, str]:
+        return {
+            # "decoder_input_ids": "input_ids",
+            "encoder_outputs": "encoder_hidden_states",
+            # "attention_mask": "encoder_attention_mask",
+        }
+
+    def overwrite_shape_and_generate_input(
+        self, dummy_input_gen: "DummyInputGenerator", input_name: str, framework: str, input_shapes: Dict
+    ):
+        dummy_input_gen.batch_size = 1
+        dummy_input = dummy_input_gen.generate(
+            input_name, framework=framework, int_dtype=self.int_dtype, float_dtype=self.float_dtype
+        )
+        return dummy_input
+

 class Speech2TextDummyAudioInputGenerator(DummyAudioInputGenerator):
     def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
diff --git a/optimum/exporters/onnx/model_patcher.py b/optimum/exporters/onnx/model_patcher.py
index 33d92ebb6b6..9679159dd7a 100644
--- a/optimum/exporters/onnx/model_patcher.py
+++ b/optimum/exporters/onnx/model_patcher.py
@@ -364,6 +364,8 @@ def __init__(
     ):
         super().__init__(config, model, model_kwargs)

+        model.vocoder = model_kwargs["vocoder_model"]
+
         def patched_forward(
             input_ids=None,
             speaker_embeddings=None,
@@ -371,6 +373,7 @@ def patched_forward(
             past_key_values=None,
             output_sequence=None,
             spectrogram=None,
+            encoder_attention_mask=None,
         ):
             use_cache = self.real_config.use_past and self.real_config.variant == "with-past"
             if self.real_config._behavior == "encoder":
@@ -387,11 +390,14 @@ def patched_forward(
                     encoder_out[0].shape[1], encoder_attention_mask
                 )

-                # TODO: that is wrong?
-                return {"encoder_out": encoder_out, "encoder_attention_mask": encoder_attention_mask}
+                return {
+                    "encoder_outputs": encoder_out.last_hidden_state,
+                    "encoder_attention_mask": encoder_attention_mask,
+                }

-            elif self.real_config._behavior == "decoder" and self.real_config.use_past_in_inputs:
-                encoder_hidden_states = encoder_outputs.last_hidden_state
+            elif self.real_config._behavior == "decoder":
+                # TODO: and self.real_config.use_past_in_inputs
+                encoder_hidden_states = encoder_outputs[0]

                 decoder_hidden_states = model.speecht5.decoder.prenet(output_sequence, speaker_embeddings)

@@ -426,9 +432,9 @@ def patched_forward(
                 prob = torch.sigmoid(model.speech_decoder_postnet.prob_out(last_decoder_output))

                 return {
+                    "output_sequence_out": output_sequence,
+                    "spectrum": spectrum,
                     "prob": prob,
-                    "output_sequence": output_sequence,
-                    "spectrum": spectrum
                     # TODO: PKV here
                 }
             elif self.real_config.is_postnet_and_vocoder:
@@ -438,8 +444,10 @@ def patched_forward(
                 spectrogram = model.speech_decoder_postnet.postnet(spectrogram)
                 spectrogram = spectrogram.squeeze(0)

-                waveform = model_kwargs["vocoder"](spectrogram)
+                waveform = model.vocoder(spectrogram)

                 return {"waveform": waveform}
+            else:
+                raise ValueError(f"Unexpected behavior for the SpeechT5 export: {self.real_config._behavior}.")

         self.patched_forward = patched_forward

diff --git a/optimum/exporters/onnx/utils.py b/optimum/exporters/onnx/utils.py
index 1ae682cce9f..aa743db3b44 100644
--- a/optimum/exporters/onnx/utils.py
+++ b/optimum/exporters/onnx/utils.py
@@ -395,7 +395,9 @@ def get_speecht5_models_for_export(
     model: Union["PreTrainedModel", "TFPreTrainedModel"], config: "OnnxConfig", model_kwargs: Optional[Dict]
 ):
     if model_kwargs is None or "vocoder" not in model_kwargs:
-        raise ValueError("The ONNX export of SpeechT5 requires the model_kwargs `vocoder` to be set.")
+        raise ValueError(
+            'The ONNX export of SpeechT5 requires a vocoder. Please pass `--model-kwargs \'{"vocoder": "vocoder_model_name_or_path"}\'` from the command line, or `model_kwargs={"vocoder": "vocoder_model_name_or_path"}` if calling main_export.'
+        )

     models_for_export = {}

@@ -406,7 +408,8 @@ def get_speecht5_models_for_export(
     if config.variant == "with-past":
         models_for_export["decoder_with_past_model"] = model

-    vocoder = SpeechT5HifiGan.from_pretrained(model_kwargs["vocoder"])
+    # TODO: more flexibility in the vocoder class?
+    vocoder = SpeechT5HifiGan.from_pretrained(model_kwargs["vocoder"]).eval()
     model_kwargs["vocoder_model"] = vocoder

     models_for_export["decoder_postnet_and_vocoder"] = model
@@ -425,7 +428,17 @@ def get_speecht5_models_for_export(
         decoder_onnx_config_with_past,
     )

-    postnet_and_vocoder_onnx_config = config.__class__(..., is_vocoder=True)
+    postnet_and_vocoder_onnx_config = config.__class__(
+        config._config,
+        task=config.task,
+        int_dtype=config.int_dtype,
+        float_dtype=config.float_dtype,
+        use_past=use_past,
+        use_past_in_inputs=False,  # Irrelevant here.
+        behavior=config._behavior,  # Irrelevant here.
+        preprocessors=config._preprocessors,
+        is_postnet_and_vocoder=True,
+    )
     models_for_export["decoder_postnet_and_vocoder"] = (
         models_for_export["decoder_postnet_and_vocoder"],
         postnet_and_vocoder_onnx_config,
diff --git a/optimum/utils/input_generators.py b/optimum/utils/input_generators.py
index 227c12315d9..72bbb2e618f 100644
--- a/optimum/utils/input_generators.py
+++ b/optimum/utils/input_generators.py
@@ -323,6 +323,7 @@ class DummyTextInputGenerator(DummyInputGenerator):
     SUPPORTED_INPUT_NAMES = (
         "input_ids",
         "attention_mask",
+        "encoder_attention_mask",
         "token_type_ids",
         "position_ids",
     )
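Usage note: a minimal sketch of how the `--model-kwargs` / `model_kwargs` option added above is meant to be used for the SpeechT5 export, following the error message introduced in optimum/exporters/onnx/utils.py. The checkpoint names, the output directory and the "text-to-speech" task string are assumptions for illustration, not values taken from the patch.

    # Python API, mirroring what the CLI command does in optimum/commands/export/onnx.py
    from optimum.exporters.onnx import main_export

    main_export(
        "microsoft/speecht5_tts",                                # assumed SpeechT5 checkpoint
        output="speecht5_onnx",                                  # arbitrary output directory
        task="text-to-speech",                                   # task string used in the base.py check above
        model_kwargs={"vocoder": "microsoft/speecht5_hifigan"},  # vocoder loaded via SpeechT5HifiGan.from_pretrained in utils.py
    )

    # CLI equivalent using the new --model-kwargs argument (parsed with json.loads):
    #   optimum-cli export onnx --model microsoft/speecht5_tts \
    #       --model-kwargs '{"vocoder": "microsoft/speecht5_hifigan"}' speecht5_onnx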