
Commit

wip bis
fxmarty committed Sep 21, 2023
1 parent 2dd5209 commit be26f71
Showing 5 changed files with 265 additions and 15 deletions.
6 changes: 6 additions & 0 deletions optimum/exporters/onnx/__main__.py
@@ -38,6 +38,7 @@
get_decoder_models_for_export,
get_encoder_decoder_models_for_export,
get_sam_models_for_export,
get_speecht5_models_for_export,
get_stable_diffusion_models_for_export,
)

@@ -69,6 +70,7 @@ def _get_submodels_and_onnx_configs(
fn_get_submodels: Optional[Callable] = None,
preprocessors: Optional[List[Any]] = None,
no_position_ids: bool = False,
model_kwargs: Optional[Dict] = None,
):
is_stable_diffusion = "stable-diffusion" in task
if not custom_architecture:
@@ -99,6 +101,7 @@
)
logger.info(f"Using the export variant {onnx_config.variant}. Available variants are:\n{all_variants}")

# TODO: this succession of if/else strongly suggests a refactor is needed.
if (
model.config.is_encoder_decoder
and task.startswith(TasksManager._ENCODER_DECODER_TASKS)
@@ -109,6 +112,8 @@
models_and_onnx_configs = get_decoder_models_for_export(model, onnx_config)
elif model.config.model_type == "sam":
models_and_onnx_configs = get_sam_models_for_export(model, onnx_config)
elif model.config.model_type == "speecht5":
models_and_onnx_configs = get_speecht5_models_for_export(model, onnx_config, model_kwargs)
else:
models_and_onnx_configs = {"model": (model, onnx_config)}

@@ -425,6 +430,7 @@ def main_export(
preprocessors=preprocessors,
_variant=_variant,
no_position_ids=no_position_ids,
model_kwargs=model_kwargs,
)

if not is_stable_diffusion:
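For illustration, here is a minimal sketch of how this export path could be driven end to end. It assumes the public `main_export` entry point shown above; the `microsoft/speecht5_tts` and `microsoft/speecht5_hifigan` checkpoints and the output directory name are illustrative choices, not mandated by the diff:

from optimum.exporters.onnx import main_export

# The SpeechT5 export requires the vocoder checkpoint to be passed through
# model_kwargs, as enforced in get_speecht5_models_for_export below.
main_export(
    "microsoft/speecht5_tts",
    output="speecht5_onnx",
    task="text-to-speech",
    model_kwargs={"vocoder": "microsoft/speecht5_hifigan"},
)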
114 changes: 110 additions & 4 deletions optimum/exporters/onnx/model_configs.py
@@ -55,7 +55,7 @@
TextSeq2SeqOnnxConfig,
VisionOnnxConfig,
)
from .model_patcher import SAMModelPatcher, WavLMModelPatcher
from .model_patcher import SAMModelPatcher, SpeechT5ModelPatcher, WavLMModelPatcher


if TYPE_CHECKING:
@@ -1143,10 +1143,116 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
common_outputs["last_hidden_state"][1] = f"{common_outputs['last_hidden_state'][1]} / 2"
return common_outputs


class DummySpeechT5InputGenerator(DummyInputGenerator):
SUPPORTED_INPUT_NAMES = ("output_sequence", "speaker_embeddings", "spectrogram")

def __init__(
self,
task: str,
normalized_config: NormalizedConfig,
sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"],
**kwargs,
):
self.task = task
self.batch_size = 1 # TODO: SpeechT5 does not support batch inference in Transformers for now.

self.sequence_length = sequence_length
self.speaker_embedding_dim = normalized_config.speaker_embedding_dim
self.num_mel_bins = normalized_config.num_mel_bins

def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
if input_name == "output_sequence":
shape = [self.batch_size, self.sequence_length, self.num_mel_bins]
elif input_name == "speaker_embeddings":
shape = [self.batch_size, self.speaker_embedding_dim]
elif input_name == "spectrogram":
shape = [20, self.num_mel_bins] # NOTE: the first axis length is arbitrary and dynamic
else:
raise ValueError(f"Unsupported input {input_name} for DummySpeechT5InputGenerator")

return self.random_float_tensor(
shape=shape,
min_value=0,
max_value=1,
framework=framework,
dtype=float_dtype,
)
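# For illustration (not part of the diff): with microsoft/speecht5_tts, where
# num_mel_bins=80 and speaker_embedding_dim=512, the generator above yields:
#   output_sequence    -> float tensor of shape (1, sequence_length, 80)
#   speaker_embeddings -> float tensor of shape (1, 512)
#   spectrogram        -> float tensor of shape (20, 80)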


class SpeechT5OnnxConfig(OnnxSeq2SeqConfigWithPast):
# TODO: batched generation for SpeechT5 is broken in Transformers (https://github.com/huggingface/transformers/pull/25943),
# so we do not support it for now.
NORMALIZED_CONFIG_CLASS = NormalizedConfig
DUMMY_INPUT_GENERATOR_CLASSES = (
DummyTextInputGenerator,
DummySpeechT5InputGenerator,
DummySeq2SeqDecoderTextInputGenerator,
T5DummySeq2SeqPastKeyValuesGenerator,
)
DUMMY_PKV_GENERATOR_CLASS = T5DummySeq2SeqPastKeyValuesGenerator

# TODO: do not truncate the output_sequence length when past key values are used.

VARIANTS = {
"transformers-like": "The following components are exported following Transformers implementation:\n\t - encoder_model.onnx: corresponds to the encoding part in https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/speecht5/modeling_speecht5.py#L2544-L2556.\n\t - decoder_model.onnx: corresponds to the decoder part in https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/speecht5/modeling_speecht5.py#L2572-L2602.\n\t - decoder_with_past_model.onnx: same as the above, with past_key_values input (KV cache filled).\n\t - decoder_postnet_and_vocoder.onnx: Decoder speech postnet and vocoder (e.g. a SpeechT5HifiGan) to generate speech from the spectrogram, as in https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/speecht5/modeling_speecht5.py#L2605-L2614.",
"without-cache": "The same as `transformers-like`, without KV cache support. This is not a recommende export as slower than `transformers-like`.",
}
DEFAULT_VARIANT = "transformers-like"

@property
def inputs(self) -> Dict[str, Dict[int, str]]:
common_inputs = {}

# Batched inference is not supported in Transformers.
if self._behavior is ConfigBehavior.ENCODER:
common_inputs["input_ids"] = {1: "encoder_sequence_length"}
elif self._behavior is ConfigBehavior.DECODER:
# NOTE: even when past is used, the decoder takes the full sequence as input, as the prenet seems to require it:
# https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/speecht5/modeling_speecht5.py#L2573
common_inputs["output_sequence"] = {1: "decoder_sequence_length"}
common_inputs["speaker_embeddings"] = {} # No dynamic shape here.
common_inputs["encoder_hidden_states"] = {1: "encoder_sequence_length"}
common_inputs["encoder_attention_mask"] = {1: "encoder_sequence_length"}

if self.variant == "transformers-like" and self.use_past_in_inputs:
# TODO: check PKV shape
self.add_past_key_values(common_inputs, direction="inputs")
elif self.is_postnet_and_vocoder:
common_inputs["spectrogram"] = {0: "n_spectrums x reduction_factor"}
else:
raise ValueError(
"self._behavior is neither encoder or decoder, and is_postnet_and_vocoder=False. This should not happen."
)

return common_inputs

@property
def outputs(self) -> Dict[str, Dict[int, str]]:
common_outputs = {}
if self._behavior is ConfigBehavior.ENCODER:
common_outputs["encoder_hidden_states"] = {1: "encoder_sequence_length"}
common_outputs["encoder_attention_mask"] = {1: "encoder_sequence_length"}
elif self._behavior is ConfigBehavior.DECODER:
common_outputs["output_sequence"] = {1: "decoder_sequence_length + 1"}
common_outputs["prob"] = {} # No dynamic shape here.
common_outputs["spectrum"] = {} # No dynamic shape here.

if self.variant == "transformers-like" and self.use_past:
# When exporting decoder models with use_cache=True, both the decoder without past and with past have the KV cache as an output.
self.add_past_key_values(common_outputs, direction="outputs")
elif self.is_postnet_and_vocoder:
common_outputs["waveform"] = {0: "n_samples"}
else:
raise ValueError(
"self._behavior is neither encoder or decoder, and is_postnet_and_vocoder=False. This should not happen."
)

return common_outputs

def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
) -> "ModelPatcher":
return SpeechT5ModelPatcher(self, model, model_kwargs=model_kwargs)


class Speech2TextDummyAudioInputGenerator(DummyAudioInputGenerator):
106 changes: 103 additions & 3 deletions optimum/exporters/onnx/model_patcher.py
@@ -15,8 +15,9 @@
import dataclasses
import functools
import inspect
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union

from transformers.models.speecht5.modeling_speecht5 import SpeechT5EncoderWithSpeechPrenet
from transformers.utils import is_torch_available


@@ -34,6 +35,18 @@
logger = logging.get_logger(__name__)


def get_argument(argument_name: str, args: List[Any], kwargs: Dict[str, Any], forward_signature):
"""
Gets the argument `argument_name` from `args` and `kwargs`, following the signature `forward_signature`.
"""
args = list(args)
if argument_name in forward_signature.parameters:
argument_index = list(forward_signature.parameters.keys()).index(argument_name)
# The argument may have been passed by keyword even though it appears in the
# signature, so fall back to kwargs when it is not present positionally.
if argument_index < len(args):
return args[argument_index]
return kwargs[argument_name]
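# Illustrative usage (not part of the diff): the helper resolves an argument
# whether it was passed positionally or by keyword, e.g.:
#
#   import inspect
#
#   def forward(input_ids, attention_mask=None):
#       ...
#
#   sig = inspect.signature(forward)
#   get_argument("attention_mask", (ids, mask), {}, sig)                   # positional
#   get_argument("attention_mask", (ids,), {"attention_mask": mask}, sig)  # keyword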


def override_arguments(args, kwargs, forward_signature, model_kwargs: Dict[str, Any]):
"""
Override the args and kwargs with the argument values from model_kwargs, following the signature forward_signature corresponding to args and kwargs.
@@ -286,9 +299,7 @@ def patched_forward(
**kwargs,
)
elif config.variant == "split":
# return_dict = get_argument(args, kwargs, signature, "return_dict")
if config.vision_encoder:
# pixel_values = get_argument(args, kwargs, signature, "pixel_values")
image_positional_embeddings = model.get_image_wide_positional_embeddings()

# repeat with batch size
@@ -342,3 +353,92 @@ def patched_forward(
return {"iou_scores": iou_predictions, "pred_masks": low_res_masks}

self.patched_forward = patched_forward


class SpeechT5ModelPatcher(ModelPatcher):
def __init__(
self,
config: "OnnxConfig",
model: Union["PreTrainedModel", "TFPreTrainedModel"],
model_kwargs: Dict[str, Any],
):
super().__init__(config, model, model_kwargs)

def patched_forward(
input_ids=None,
speaker_embeddings=None,
encoder_outputs=None,
encoder_attention_mask=None,
past_key_values=None,
output_sequence=None,
spectrogram=None,
):
use_cache = self.real_config.use_past and self.real_config.variant == "transformers-like"
if self.real_config._behavior == "encoder":
encoder_attention_mask = torch.ones_like(input_ids)

encoder_out = model.speecht5.encoder(
input_values=input_ids,
attention_mask=encoder_attention_mask,
return_dict=True,
)
# downsample encoder attention mask
if isinstance(model.speecht5.encoder, SpeechT5EncoderWithSpeechPrenet):
encoder_attention_mask = model.speecht5.encoder.prenet._get_feature_vector_attention_mask(
encoder_out[0].shape[1], encoder_attention_mask
)

return {
"encoder_hidden_states": encoder_out.last_hidden_state,
"encoder_attention_mask": encoder_attention_mask,
}

elif self.real_config._behavior == "decoder":
encoder_hidden_states = encoder_outputs.last_hidden_state

decoder_hidden_states = model.speecht5.decoder.prenet(output_sequence, speaker_embeddings)

# Run the decoder layers on the last element of the prenet output.
decoder_out = model.speecht5.decoder.wrapped_decoder(
hidden_states=decoder_hidden_states[:, -1:],
attention_mask=None,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=False,
return_dict=True,
)

last_decoder_output = decoder_out.last_hidden_state[0, -1]
past_key_values = decoder_out.past_key_values

# Predict the new mel spectrum for this step in the sequence.
spectrum = model.speech_decoder_postnet.feat_out(last_decoder_output)
spectrum = spectrum.view(model.config.reduction_factor, model.config.num_mel_bins)

# NOTE: extending the spectrogram is to be handled outside of the ONNX model.
# spectrogram.append(spectrum)

# Extend the output sequence with the new mel spectrum.
output_sequence = torch.cat(
(output_sequence, spectrum[-1].view(1, 1, model.config.num_mel_bins)), dim=1
)

# Predict the probability that this is the stop token.
prob = torch.sigmoid(model.speech_decoder_postnet.prob_out(last_decoder_output))

outputs = {
"prob": prob,
"output_sequence": output_sequence,
"spectrum": spectrum,
}
if use_cache:
outputs["past_key_values"] = past_key_values
return outputs
elif self.real_config.is_postnet_and_vocoder:
# spectrogram = torch.cat(spectrogram, dim=0).unsqueeze(0)
spectrogram = spectrogram.unsqueeze(0)
spectrogram = model.speech_decoder_postnet.postnet(spectrogram)
spectrogram = spectrogram.squeeze(0)

waveform = model_kwargs["vocoder_model"](spectrogram)

return {"waveform": waveform}

self.patched_forward = patched_forward
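For context, a hedged sketch of how the exported subgraphs are meant to be chained at inference time, following the `transformers-like` variant layout described in `SpeechT5OnnxConfig` above. The file names, the 0.5 stop threshold, the step limit, and the shape constants (num_mel_bins=80, speaker_embedding_dim=512, matching microsoft/speecht5_tts) are assumptions, and the KV cache outputs are ignored for brevity:

import numpy as np
import onnxruntime as ort

encoder = ort.InferenceSession("encoder_model.onnx")
decoder = ort.InferenceSession("decoder_model.onnx")
postnet_vocoder = ort.InferenceSession("decoder_postnet_and_vocoder.onnx")

input_ids = np.array([[4, 8, 15, 16, 23]], dtype=np.int64)  # dummy token ids
speaker_embeddings = np.zeros((1, 512), dtype=np.float32)

encoder_hidden_states, encoder_attention_mask = encoder.run(
    ["encoder_hidden_states", "encoder_attention_mask"],
    {"input_ids": input_ids},
)

num_mel_bins = 80
output_sequence = np.zeros((1, 1, num_mel_bins), dtype=np.float32)
spectrogram = []

# Autoregressive loop: one decoder step per iteration, stopping once the
# predicted stop probability crosses the threshold.
for _ in range(1000):
    output_sequence, prob, spectrum = decoder.run(
        ["output_sequence", "prob", "spectrum"],
        {
            "output_sequence": output_sequence,
            "speaker_embeddings": speaker_embeddings,
            "encoder_hidden_states": encoder_hidden_states,
            "encoder_attention_mask": encoder_attention_mask,
        },
    )
    spectrogram.append(spectrum)
    if np.any(prob >= 0.5):
        break

# Post-process the accumulated mel frames and synthesize the waveform.
waveform = postnet_vocoder.run(
    ["waveform"], {"spectrogram": np.concatenate(spectrogram, axis=0)}
)[0]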
46 changes: 45 additions & 1 deletion optimum/exporters/onnx/utils.py
@@ -19,6 +19,7 @@

import torch
from packaging import version
from transformers.models.speecht5.modeling_speecht5 import SpeechT5HifiGan
from transformers.utils import is_tf_available, is_torch_available

from ...utils import (
@@ -361,7 +362,7 @@ def _get_submodels_for_export_sam(model, variant):
if variant == "monolith":
models_for_export["model"] = model
else:
# We use the model patcher to patch their forward method.
# Instead, we use the model patcher to patch their forward method.
models_for_export["vision_encoder"] = model
models_for_export["prompt_encoder_mask_decoder"] = model

@@ -390,6 +391,49 @@ def get_sam_models_for_export(model: Union["PreTrainedModel", "TFPreTrainedModel"
return models_for_export


def get_speecht5_models_for_export(
model: Union["PreTrainedModel", "TFPreTrainedModel"], config: "OnnxConfig", model_kwargs: Optional[Dict]
):
if model_kwargs is None or "vocoder" not in model_kwargs:
raise ValueError('The ONNX export of SpeechT5 requires `model_kwargs["vocoder"]` to be set to a vocoder checkpoint name or path.')

models_for_export = {}

# Instead, we use the model patcher to patch their forward method.
models_for_export["encoder_model"] = model
models_for_export["decoder_model"] = model

if config.variant == "transformers-like":
models_for_export["decoder_with_past_model"] = model

vocoder = SpeechT5HifiGan.from_pretrained(model_kwargs["vocoder"])
model_kwargs["vocoder_model"] = vocoder

models_for_export["decoder_postnet_and_vocoder"] = model

encoder_onnx_config = config.with_behavior("encoder")

use_past = config.variant == "transformers-like"
decoder_onnx_config = config.with_behavior("decoder", use_past=use_past, use_past_in_inputs=False)

models_for_export[ONNX_ENCODER_NAME] = (models_for_export[ONNX_ENCODER_NAME], encoder_onnx_config)
models_for_export[ONNX_DECODER_NAME] = (models_for_export[ONNX_DECODER_NAME], decoder_onnx_config)
if config.variant == "transformers-like":
decoder_onnx_config_with_past = config.with_behavior("decoder", use_past=True, use_past_in_inputs=True)
models_for_export[ONNX_DECODER_WITH_PAST_NAME] = (
models_for_export[ONNX_DECODER_WITH_PAST_NAME],
decoder_onnx_config_with_past,
)

postnet_and_vocoder_onnx_config = config.__class__(..., is_postnet_and_vocoder=True)
models_for_export["decoder_postnet_and_vocoder"] = (
models_for_export["decoder_postnet_and_vocoder"],
postnet_and_vocoder_onnx_config,
)

return models_for_export
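# For reference (illustrative, not part of the diff): with the default
# "transformers-like" variant, the mapping returned above contains four
# (submodel, onnx_config) pairs, keyed by "encoder_model", "decoder_model",
# "decoder_with_past_model" and "decoder_postnet_and_vocoder".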


def override_diffusers_2_0_attn_processors(model):
for _, submodule in model.named_modules():
if isinstance(submodule, Attention):
8 changes: 1 addition & 7 deletions optimum/exporters/tasks.py
@@ -265,8 +265,6 @@ class TasksManager:
("pt", "visual-bert", "question-answering"): ("transformers", "VisualBertForQuestionAnswering"),
# VisionEncoderDecoderModel is not registered in AutoModelForDocumentQuestionAnswering
("pt", "vision-encoder-decoder", "document-question-answering"): ("transformers", "VisionEncoderDecoderModel"),
# audio-to-audio task has no AutoModel class.
("pt", "speecht5", "audio-to-audio"): ("transformers", "SpeechT5ForSpeechToSpeech"),
}

# TODO: why is feature-extraction-with-past here?
@@ -841,13 +839,9 @@ class TasksManager:
"automatic-speech-recognition-with-past",
onnx="Speech2TextOnnxConfig",
),
# TODO: SpeechT5 can also support audio-to-audio and automatic-speech-recognition.
"speecht5": supported_tasks_mapping(
"audio-to-audio",
"audio-to-audio-with-past",
"automatic-speech-recognition",
"automatic-speech-recognition-with-past",
"text-to-speech",
"text-to-speech-with-past",
onnx="SpeechT5OnnxConfig",
),
"splinter": supported_tasks_mapping(
