From be26f711926c99605f303fc3059b146edddf5617 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?F=C3=A9lix=20Marty?= <9808326+fxmarty@users.noreply.github.com>
Date: Thu, 21 Sep 2023 14:47:39 +0200
Subject: [PATCH] wip bis

---
 optimum/exporters/onnx/__main__.py      |   6 ++
 optimum/exporters/onnx/model_configs.py | 114 +++++++++++++++++++++++-
 optimum/exporters/onnx/model_patcher.py | 106 +++++++++++++++++++++-
 optimum/exporters/onnx/utils.py         |  46 +++++++++-
 optimum/exporters/tasks.py              |   8 +-
 5 files changed, 265 insertions(+), 15 deletions(-)

diff --git a/optimum/exporters/onnx/__main__.py b/optimum/exporters/onnx/__main__.py
index 16a18afc552..8b93d487dde 100644
--- a/optimum/exporters/onnx/__main__.py
+++ b/optimum/exporters/onnx/__main__.py
@@ -38,6 +38,7 @@
     get_decoder_models_for_export,
     get_encoder_decoder_models_for_export,
     get_sam_models_for_export,
+    get_speecht5_models_for_export,
     get_stable_diffusion_models_for_export,
 )

@@ -69,6 +70,7 @@ def _get_submodels_and_onnx_configs(
     fn_get_submodels: Optional[Callable] = None,
     preprocessors: Optional[List[Any]] = None,
     no_position_ids: bool = False,
+    model_kwargs: Optional[Dict] = None,
 ):
     is_stable_diffusion = "stable-diffusion" in task
     if not custom_architecture:
@@ -99,6 +101,7 @@ def _get_submodels_and_onnx_configs(
         )
         logger.info(f"Using the export variant {onnx_config.variant}. Available variants are:\n{all_variants}")

+        # TODO: this succession of if/else strongly suggests a refactor is needed.
         if (
             model.config.is_encoder_decoder
             and task.startswith(TasksManager._ENCODER_DECODER_TASKS)
@@ -109,6 +112,8 @@ def _get_submodels_and_onnx_configs(
             models_and_onnx_configs = get_decoder_models_for_export(model, onnx_config)
         elif model.config.model_type == "sam":
             models_and_onnx_configs = get_sam_models_for_export(model, onnx_config)
+        elif model.config.model_type == "speecht5":
+            models_and_onnx_configs = get_speecht5_models_for_export(model, onnx_config, model_kwargs)
         else:
             models_and_onnx_configs = {"model": (model, onnx_config)}

@@ -425,6 +430,7 @@ def main_export(
         preprocessors=preprocessors,
         _variant=_variant,
         no_position_ids=no_position_ids,
+        model_kwargs=model_kwargs,
     )

     if not is_stable_diffusion:
diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py
index 499d4f6f03d..febb4c40073 100644
--- a/optimum/exporters/onnx/model_configs.py
+++ b/optimum/exporters/onnx/model_configs.py
@@ -55,7 +55,7 @@
     TextSeq2SeqOnnxConfig,
     VisionOnnxConfig,
 )
-from .model_patcher import SAMModelPatcher, WavLMModelPatcher
+from .model_patcher import SAMModelPatcher, SpeechT5ModelPatcher, WavLMModelPatcher


 if TYPE_CHECKING:
@@ -1143,10 +1143,116 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
             common_outputs["last_hidden_state"][1] = f"{common_outputs['last_hidden_state'][1]} / 2"
         return common_outputs

-class SpeechT5OnnxConfig():
-    NORMALIZED_CONFIG_CLASS =
-
+class DummySpeechT5InputGenerator(DummyInputGenerator):
+    SUPPORTED_INPUT_NAMES = ("output_sequence", "speaker_embeddings", "spectrogram")
+
+    def __init__(
+        self,
+        task: str,
+        normalized_config: NormalizedConfig,
+        sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"],
+        **kwargs,
+    ):
+        self.task = task
+        self.batch_size = 1  # TODO: SpeechT5 does not support batch inference in Transformers for now.
+
+        self.sequence_length = sequence_length
+        self.speaker_embedding_dim = normalized_config.speaker_embedding_dim
+        self.num_mel_bins = normalized_config.num_mel_bins
+
+    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
+        if input_name == "output_sequence":
+            shape = [self.batch_size, self.sequence_length, self.num_mel_bins]
+        elif input_name == "speaker_embeddings":
+            shape = [self.batch_size, self.speaker_embedding_dim]
+        elif input_name == "spectrogram":
+            shape = [20, self.num_mel_bins]  # NOTE: the length of the first axis is arbitrary and dynamic.
+        else:
+            raise ValueError(f"Unsupported input {input_name} for DummySpeechT5InputGenerator")
+
+        return self.random_float_tensor(
+            shape=shape,
+            min_value=0,
+            max_value=1,
+            framework=framework,
+            dtype=float_dtype,
+        )
+
+
+class SpeechT5OnnxConfig(OnnxSeq2SeqConfigWithPast):
+    # TODO: Transformers batched generation for SpeechT5 is broken
+    # (https://github.com/huggingface/transformers/pull/25943), so we do not support it for now.
+    NORMALIZED_CONFIG_CLASS = None
+    DUMMY_INPUT_GENERATOR_CLASSES = (
+        DummyTextInputGenerator,
+        DummySeq2SeqDecoderTextInputGenerator,
+        DummySpeechT5InputGenerator,
+        T5DummySeq2SeqPastKeyValuesGenerator,
+    )
+    DUMMY_PKV_GENERATOR_CLASS = T5DummySeq2SeqPastKeyValuesGenerator
+
+    # TODO: do not cut the output_sequence length when past key values are used.
+
+    VARIANTS = {
+        "transformers-like": "The following components are exported, following the Transformers implementation:\n\t - encoder_model.onnx: corresponds to the encoding part in https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/speecht5/modeling_speecht5.py#L2544-L2556.\n\t - decoder_model.onnx: corresponds to the decoder part in https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/speecht5/modeling_speecht5.py#L2572-L2602.\n\t - decoder_with_past_model.onnx: same as the above, with an additional past_key_values input (KV cache filled).\n\t - decoder_postnet_and_vocoder.onnx: the decoder speech postnet and the vocoder (e.g. a SpeechT5HifiGan) that generate speech from the spectrogram, as in https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/speecht5/modeling_speecht5.py#L2605-L2614.",
+        "without-cache": "The same as `transformers-like`, but without KV cache support. This export is not recommended, as it is slower than `transformers-like`.",
+    }
+    DEFAULT_VARIANT = "transformers-like"
+
+    @property
+    def inputs(self) -> Dict[str, Dict[int, str]]:
+        common_inputs = {}
+
+        # Batched inference is not supported in Transformers.
+        if self._behavior is ConfigBehavior.ENCODER:
+            common_inputs["input_ids"] = {1: "encoder_sequence_length"}
+        elif self._behavior is ConfigBehavior.DECODER:
+            # NOTE: even when past key values are used, the decoder takes the whole sequence as input, as the prenet seems to require it:
+            # https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/speecht5/modeling_speecht5.py#L2573
+            common_inputs["output_sequence"] = {1: "decoder_sequence_length"}
+            common_inputs["speaker_embeddings"] = {}  # No dynamic shape here.
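+            # NOTE: speaker_embeddings is a fixed-size speaker vector of shape [batch_size, speaker_embedding_dim]
+            # (an x-vector in the official checkpoints, see DummySpeechT5InputGenerator above), hence no dynamic axis.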
+            common_inputs["encoder_hidden_states"] = {1: "encoder_sequence_length"}
+            common_inputs["encoder_attention_mask"] = {1: "encoder_sequence_length"}
+
+            if self.variant == "transformers-like" and self.use_past_in_inputs:
+                # TODO: check the past key values shape.
+                self.add_past_key_values(common_inputs, direction="inputs")
+        elif self.is_postnet_and_vocoder:
+            common_inputs["spectrogram"] = {0: "n_spectrums x reduction_factor"}
+        else:
+            raise ValueError(
+                "self._behavior is neither encoder nor decoder, and is_postnet_and_vocoder=False. This should not happen."
+            )
+
+        return common_inputs
+
+    @property
+    def outputs(self) -> Dict[str, Dict[int, str]]:
+        common_outputs = {}
+        if self._behavior is ConfigBehavior.ENCODER:
+            common_outputs["encoder_hidden_states"] = {1: "encoder_sequence_length"}
+            common_outputs["encoder_attention_mask"] = {1: "encoder_sequence_length"}
+        elif self._behavior is ConfigBehavior.DECODER:
+            common_outputs["output_sequence"] = {1: "decoder_sequence_length + 1"}
+            common_outputs["prob"] = {}  # No dynamic shape here.
+            common_outputs["spectrum"] = {}  # No dynamic shape here.
+
+            if self.variant == "transformers-like" and self.use_past:
+                # When exporting decoder models with use_cache=True, both the decoder without past and the decoder with past have the KV cache as an output.
+                self.add_past_key_values(common_outputs, direction="outputs")
+        elif self.is_postnet_and_vocoder:
+            common_outputs["waveform"] = {0: "n_samples"}
+        else:
+            raise ValueError(
+                "self._behavior is neither encoder nor decoder, and is_postnet_and_vocoder=False. This should not happen."
+            )
+
+        return common_outputs
+
+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ) -> "ModelPatcher":
+        return SpeechT5ModelPatcher(self, model, model_kwargs=model_kwargs)


 class Speech2TextDummyAudioInputGenerator(DummyAudioInputGenerator):
diff --git a/optimum/exporters/onnx/model_patcher.py b/optimum/exporters/onnx/model_patcher.py
index e6b50b6dc08..b9abe29421a 100644
--- a/optimum/exporters/onnx/model_patcher.py
+++ b/optimum/exporters/onnx/model_patcher.py
@@ -15,8 +15,9 @@
 import dataclasses
 import functools
 import inspect
-from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union

+from transformers.models.speecht5.modeling_speecht5 import SpeechT5EncoderWithSpeechPrenet
 from transformers.utils import is_torch_available


@@ -34,6 +35,18 @@
 logger = logging.get_logger(__name__)


+def get_argument(argument_name: str, args: List[Any], kwargs: Dict[str, Any], forward_signature):
+    """
+    Gets the argument `argument_name` from `args` and `kwargs`, following the signature `forward_signature`.
+    """
+    args = list(args)
+    if argument_name in forward_signature.parameters:
+        argument_index = list(forward_signature.parameters.keys()).index(argument_name)
+        # The argument may have been passed positionally or as a keyword: fall back to kwargs when
+        # it is not found among the positional arguments.
+        if argument_index < len(args):
+            return args[argument_index]
+    return kwargs[argument_name]
+
+
 def override_arguments(args, kwargs, forward_signature, model_kwargs: Dict[str, Any]):
     """
     Override the args and kwargs with the argument values from model_kwargs, following the signature forward_signature corresponding to args and kwargs.
@@ -286,9 +299,7 @@ def patched_forward(
                     **kwargs,
                 )
             elif config.variant == "split":
-                # return_dict = get_argument(args, kwargs, signature, "return_dict")
                 if config.vision_encoder:
-                    # pixel_values = get_argument(args, kwargs, signature, "pixel_values")
                     image_positional_embeddings = model.get_image_wide_positional_embeddings()

                     # repeat with batch size
@@ -342,3 +353,92 @@ def patched_forward(
             return {"iou_scores": iou_predictions, "pred_masks": low_res_masks}

         self.patched_forward = patched_forward
+
+
+class SpeechT5ModelPatcher(ModelPatcher):
+    def __init__(
+        self,
+        config: "OnnxConfig",
+        model: Union["PreTrainedModel", "TFPreTrainedModel"],
+        model_kwargs: Dict[str, Any],
+    ):
+        super().__init__(config, model, model_kwargs)
+
+        def patched_forward(
+            input_ids=None,
+            speaker_embeddings=None,
+            encoder_outputs=None,
+            encoder_attention_mask=None,
+            past_key_values=None,
+            output_sequence=None,
+            spectrogram=None,
+        ):
+            use_cache = self.real_config.use_past and self.real_config.variant == "transformers-like"
+
+            if self.real_config._behavior == "encoder":
+                encoder_attention_mask = torch.ones_like(input_ids)
+
+                encoder_out = model.speecht5.encoder(
+                    input_values=input_ids,
+                    attention_mask=encoder_attention_mask,
+                    return_dict=True,
+                )
+                # Downsample the encoder attention mask.
+                if isinstance(model.speecht5.encoder, SpeechT5EncoderWithSpeechPrenet):
+                    encoder_attention_mask = model.speecht5.encoder.prenet._get_feature_vector_attention_mask(
+                        encoder_out[0].shape[1], encoder_attention_mask
+                    )
+
+                # The keys match the encoder outputs declared in SpeechT5OnnxConfig.
+                return {
+                    "encoder_hidden_states": encoder_out.last_hidden_state,
+                    "encoder_attention_mask": encoder_attention_mask,
+                }
+
+            elif self.real_config._behavior == "decoder":
+                encoder_hidden_states = encoder_outputs.last_hidden_state
+
+                decoder_hidden_states = model.speecht5.decoder.prenet(output_sequence, speaker_embeddings)
+
+                # Run the decoder layers on the last element of the prenet output.
+                decoder_out = model.speecht5.decoder.wrapped_decoder(
+                    hidden_states=decoder_hidden_states[:, -1:],
+                    attention_mask=None,
+                    encoder_hidden_states=encoder_hidden_states,
+                    encoder_attention_mask=encoder_attention_mask,
+                    past_key_values=past_key_values,
+                    use_cache=use_cache,
+                    output_attentions=False,
+                    return_dict=True,
+                )
+
+                last_decoder_output = decoder_out.last_hidden_state[0, -1]
+                past_key_values = decoder_out.past_key_values
+
+                # Predict the new mel spectrum for this step in the sequence.
+                spectrum = model.speech_decoder_postnet.feat_out(last_decoder_output)
+                spectrum = spectrum.view(model.config.reduction_factor, model.config.num_mel_bins)
+
+                # NOTE: extending the spectrogram is to be handled outside of the exported ONNX model.
+                # spectrogram.append(spectrum)
+
+                # Extend the output sequence with the new mel spectrum.
+                output_sequence = torch.cat(
+                    (output_sequence, spectrum[-1].view(1, 1, model.config.num_mel_bins)), dim=1
+                )
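+                # NOTE: the caller is expected to run this decoder step in a loop, feeding
+                # output_sequence back in at each iteration and stopping once prob exceeds a
+                # threshold, as SpeechT5ForTextToSpeech.generate_speech does in Transformers.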
+                # Predict the probability that this is the stop token.
+                prob = torch.sigmoid(model.speech_decoder_postnet.prob_out(last_decoder_output))
+
+                return {
+                    "prob": prob,
+                    "output_sequence": output_sequence,
+                    "spectrum": spectrum,
+                    # TODO: also return the past key values here.
+                }
+            elif self.real_config.is_postnet_and_vocoder:
+                # spectrogram = torch.cat(spectrogram, dim=0).unsqueeze(0)
+                spectrogram = spectrogram.unsqueeze(0)
+                spectrogram = model.speech_decoder_postnet.postnet(spectrogram)
+                spectrogram = spectrogram.squeeze(0)
+
+                # "vocoder_model" is set in get_speecht5_models_for_export from the "vocoder" checkpoint name.
+                waveform = model_kwargs["vocoder_model"](spectrogram)
+
+                return {"waveform": waveform}
+
+        self.patched_forward = patched_forward
diff --git a/optimum/exporters/onnx/utils.py b/optimum/exporters/onnx/utils.py
index 55850451aa7..a24cde52135 100644
--- a/optimum/exporters/onnx/utils.py
+++ b/optimum/exporters/onnx/utils.py
@@ -19,6 +19,7 @@

 import torch
 from packaging import version
+from transformers.models.speecht5.modeling_speecht5 import SpeechT5HifiGan
 from transformers.utils import is_tf_available, is_torch_available

 from ...utils import (
@@ -361,7 +362,7 @@ def _get_submodels_for_export_sam(model, variant):
     if variant == "monolith":
         models_for_export["model"] = model
     else:
-        # We use the model patcher to patch their forward method.
+        # Instead, we use the model patcher to patch their forward method.
         models_for_export["vision_encoder"] = model
         models_for_export["prompt_encoder_mask_decoder"] = model

@@ -390,6 +391,49 @@ def get_sam_models_for_export(model: Union["PreTrainedModel", "TFPreTrainedModel
     return models_for_export


+def get_speecht5_models_for_export(
+    model: Union["PreTrainedModel", "TFPreTrainedModel"], config: "OnnxConfig", model_kwargs: Optional[Dict]
+):
+    if model_kwargs is None or "vocoder" not in model_kwargs:
+        raise ValueError("The ONNX export of SpeechT5 requires a `vocoder` entry in model_kwargs.")
+
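+    # NOTE: SpeechT5 checkpoints do not bundle a vocoder; a HiFi-GAN checkpoint name (e.g. the
+    # canonical "microsoft/speecht5_hifigan") has to be passed so that the spectrogram-to-waveform
+    # step can be exported in decoder_postnet_and_vocoder.onnx.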
+    models_for_export = {}
+
+    # Instead, we use the model patcher to patch their forward method.
+    models_for_export["encoder_model"] = model
+    models_for_export["decoder_model"] = model
+
+    if config.variant == "transformers-like":
+        models_for_export["decoder_with_past_model"] = model
+
+    vocoder = SpeechT5HifiGan.from_pretrained(model_kwargs["vocoder"])
+    model_kwargs["vocoder_model"] = vocoder
+
+    models_for_export["decoder_postnet_and_vocoder"] = model
+
+    encoder_onnx_config = config.with_behavior("encoder")
+
+    use_past = config.variant == "transformers-like"
+    decoder_onnx_config = config.with_behavior("decoder", use_past=use_past, use_past_in_inputs=False)
+
+    models_for_export[ONNX_ENCODER_NAME] = (models_for_export[ONNX_ENCODER_NAME], encoder_onnx_config)
+    models_for_export[ONNX_DECODER_NAME] = (models_for_export[ONNX_DECODER_NAME], decoder_onnx_config)
+
+    if config.variant == "transformers-like":
+        decoder_onnx_config_with_past = config.with_behavior("decoder", use_past=True, use_past_in_inputs=True)
+        models_for_export[ONNX_DECODER_WITH_PAST_NAME] = (
+            models_for_export[ONNX_DECODER_WITH_PAST_NAME],
+            decoder_onnx_config_with_past,
+        )
+
+    postnet_and_vocoder_onnx_config = config.__class__(..., is_vocoder=True)
+    models_for_export["decoder_postnet_and_vocoder"] = (
+        models_for_export["decoder_postnet_and_vocoder"],
+        postnet_and_vocoder_onnx_config,
+    )
+
+    return models_for_export
+
+
 def override_diffusers_2_0_attn_processors(model):
     for _, submodule in model.named_modules():
         if isinstance(submodule, Attention):
diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py
index 99ba5f1f6e1..aff43b07ad9 100644
--- a/optimum/exporters/tasks.py
+++ b/optimum/exporters/tasks.py
@@ -265,8 +265,6 @@ class TasksManager:
         ("pt", "visual-bert", "question-answering"): ("transformers", "VisualBertForQuestionAnswering"),
         # VisionEncoderDecoderModel is not registered in AutoModelForDocumentQuestionAnswering
         ("pt", "vision-encoder-decoder", "document-question-answering"): ("transformers", "VisionEncoderDecoderModel"),
-        # audio-to-audio task has no AutoModel class.
-        ("pt", "speecht5", "audio-to-audio"): ("transformers", "SpeechT5ForSpeechToSpeech"),
     }

     # TODO: why feature-extraction-with-past is here?
@@ -841,13 +839,9 @@ class TasksManager:
             "automatic-speech-recognition-with-past",
             onnx="Speech2TextOnnxConfig",
         ),
+        # TODO: SpeechT5 can also support audio-to-audio and automatic-speech-recognition.
         "speecht5": supported_tasks_mapping(
-            "audio-to-audio",
-            "audio-to-audio-with-past",
-            "automatic-speech-recognition",
-            "automatic-speech-recognition-with-past",
             "text-to-speech",
-            "text-to-speech-with-past",
             onnx="SpeechT5OnnxConfig",
         ),
         "splinter": supported_tasks_mapping(
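--
Usage sketch (not part of the patch): with this patch applied, the export can be driven through `main_export`, passing the vocoder checkpoint via `model_kwargs`. The checkpoint names below are assumptions; `microsoft/speecht5_tts` and `microsoft/speecht5_hifigan` are the canonical pair, but any compatible SpeechT5 model and HiFi-GAN vocoder should work:

    from optimum.exporters.onnx import main_export

    # Exports encoder_model.onnx, decoder_model.onnx, decoder_with_past_model.onnx
    # and decoder_postnet_and_vocoder.onnx into the output directory.
    main_export(
        model_name_or_path="microsoft/speecht5_tts",
        output="speecht5_onnx",
        task="text-to-speech",
        model_kwargs={"vocoder": "microsoft/speecht5_hifigan"},
    )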