Skip to content

Commit

Permalink
working export
Browse files Browse the repository at this point in the history
  • Loading branch information
fxmarty committed Sep 21, 2023
1 parent d181ad2 commit 54d3bc7
Show file tree
Hide file tree
Showing 8 changed files with 116 additions and 39 deletions.
7 changes: 7 additions & 0 deletions optimum/commands/export/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"""Defines the command line for the export with ONNX."""

import argparse
import json
from pathlib import Path
from typing import TYPE_CHECKING

Expand Down Expand Up @@ -143,6 +144,11 @@ def parse_args_onnx(parser):
"Disable the use of position_ids for text-generation models that require it for batched generation. This argument is introduced for backward compatibility and will be removed in a future release of Optimum."
),
)
optional_group.add_argument(
"--model-kwargs",
type=json.loads,
help=("Any kwargs passed to the model forward, or used to customize the export for a given model."),
)

input_group = parser.add_argument_group(
"Input shapes (if necessary, this allows to override the shapes of the input given to the ONNX exporter, that requires an example input)."
Expand Down Expand Up @@ -256,5 +262,6 @@ def run(self):
_variant=self.args.variant,
library_name=self.args.library_name,
no_position_ids=self.args.no_position_ids,
model_kwargs=self.args.model_kwargs,
**input_shapes,
)
2 changes: 1 addition & 1 deletion optimum/exporters/onnx/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def _get_submodels_and_onnx_configs(

onnx_config.variant = _variant
all_variants = "\n".join(
[f"\t- {name}: {description}" for name, description in onnx_config.VARIANTS.items()]
[f" - {name}: {description}" for name, description in onnx_config.VARIANTS.items()]
)
logger.info(f"Using the export variant {onnx_config.variant}. Available variants are:\n{all_variants}")

Expand Down
8 changes: 6 additions & 2 deletions optimum/exporters/onnx/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,8 @@ def __init__(
int_dtype: str = "int64",
float_dtype: str = "fp32",
):
if task not in self._TASK_TO_COMMON_OUTPUTS:
# Isn't this check useless?
if task not in self._TASK_TO_COMMON_OUTPUTS and task != "text-to-speech":
raise ValueError(
f"{task} is not a supported task, supported tasks: {', '.join(self._TASK_TO_COMMON_OUTPUTS.keys())}"
)
Expand Down Expand Up @@ -808,7 +809,8 @@ def with_behavior(
"""
if isinstance(behavior, str) and not isinstance(behavior, ConfigBehavior):
behavior = ConfigBehavior(behavior)
return self.__class__(

onnx_config = self.__class__(
self._config,
task=self.task,
int_dtype=self.int_dtype,
Expand All @@ -818,6 +820,8 @@ def with_behavior(
behavior=behavior,
preprocessors=self._preprocessors,
)
onnx_config.variant = self.variant
return onnx_config

@property
def outputs(self) -> Dict[str, Dict[int, str]]:
Expand Down
39 changes: 20 additions & 19 deletions optimum/exporters/onnx/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,25 +560,26 @@ def remap(value):
if is_torch_less_than_1_11:
raise RuntimeError("The ONNX export using the PyTorch framework is only supported for v1.11+")
else:
with config.patch_model_for_export(model, model_kwargs=model_kwargs):
check_dummy_inputs_are_allowed(model, dummy_inputs)

inputs = config.ordered_inputs(model)
input_names = list(inputs.keys())
output_names = list(config.outputs.keys())

# Export can work with named args but the dict containing named args has to be the last element of the args
# tuple.
onnx_export(
model,
(dummy_inputs,),
f=output.as_posix(),
input_names=input_names,
output_names=output_names,
dynamic_axes=dict(chain(inputs.items(), config.outputs.items())),
do_constant_folding=True,
opset_version=opset,
)
with torch.no_grad():
with config.patch_model_for_export(model, model_kwargs=model_kwargs):
check_dummy_inputs_are_allowed(model, dummy_inputs)

inputs = config.ordered_inputs(model)
input_names = list(inputs.keys())
output_names = list(config.outputs.keys())

# Export can work with named args but the dict containing named args has to be the last element of the args
# tuple.
onnx_export(
model,
(dummy_inputs,),
f=output.as_posix(),
input_names=input_names,
output_names=output_names,
dynamic_axes=dict(chain(inputs.items(), config.outputs.items())),
do_constant_folding=True,
opset_version=opset,
)

# check if external data was exported
# TODO: this is quite inefficient as we load in memory if models are <2GB without external data
Expand Down
57 changes: 50 additions & 7 deletions optimum/exporters/onnx/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
BloomDummyPastKeyValuesGenerator,
DummyAudioInputGenerator,
DummyDecoderTextInputGenerator,
DummyInputGenerator,
DummyPastKeyValuesGenerator,
DummyPix2StructInputGenerator,
DummyPointsGenerator,
Expand Down Expand Up @@ -56,7 +57,7 @@
VisionOnnxConfig,
)
from .model_patcher import SAMModelPatcher, SpeechT5ModelPatcher, WavLMModelPatcher
from ...utils import DummyInputGenerator


if TYPE_CHECKING:
from transformers import PretrainedConfig
Expand Down Expand Up @@ -1158,7 +1159,7 @@ def __init__(

self.sequence_length = sequence_length
self.speaker_embedding_dim = normalized_config.speaker_embedding_dim
self.num_mel_bins = normalized_config.speaker_embedding_dim
self.num_mel_bins = normalized_config.num_mel_bins

def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
if input_name == "output_sequence":
Expand All @@ -1182,11 +1183,12 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int
class SpeechT5OnnxConfig(OnnxSeq2SeqConfigWithPast):
# TODO: Batched generation for SpeechT5 is broken in Transformers (https://github.com/huggingface/transformers/pull/25943),
# so we do not support it for now.
NORMALIZED_CONFIG_CLASS = None
NORMALIZED_CONFIG_CLASS = NormalizedConfig
DUMMY_INPUT_GENERATOR_CLASSES = (
DummyTextInputGenerator,
DummySeq2SeqDecoderTextInputGenerator,
T5DummySeq2SeqPastKeyValuesGenerator,
DummySpeechT5InputGenerator,
)
DUMMY_PKV_GENERATOR_CLASS = T5DummySeq2SeqPastKeyValuesGenerator

Expand All @@ -1198,6 +1200,30 @@ class SpeechT5OnnxConfig(OnnxSeq2SeqConfigWithPast):
}
DEFAULT_VARIANT = "with-past"

def __init__(
    self,
    config: "PretrainedConfig",
    task: str = "feature-extraction",
    int_dtype: str = "int64",
    float_dtype: str = "fp32",
    use_past: bool = False,
    use_past_in_inputs: bool = False,
    behavior: ConfigBehavior = ConfigBehavior.MONOLITH,
    preprocessors: Optional[List[Any]] = None,
    is_postnet_and_vocoder: bool = False,
):
    """Build the ONNX config, additionally recording the `is_postnet_and_vocoder` flag.

    Every argument except `is_postnet_and_vocoder` is handed unchanged to the parent
    constructor. `is_postnet_and_vocoder` is only stored on the instance.
    """
    parent_kwargs = {
        "config": config,
        "task": task,
        "int_dtype": int_dtype,
        "float_dtype": float_dtype,
        "use_past": use_past,
        "use_past_in_inputs": use_past_in_inputs,
        "behavior": behavior,
        "preprocessors": preprocessors,
    }
    super().__init__(**parent_kwargs)

    # Set after the parent initialization, as the parent constructor does not know about this flag.
    self.is_postnet_and_vocoder = is_postnet_and_vocoder

@property
def inputs(self) -> Dict[str, Dict[int, str]]:
common_inputs = {}
Expand All @@ -1210,7 +1236,7 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
# https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/speecht5/modeling_speecht5.py#L2573
common_inputs["output_sequence"] = {1: "decoder_sequence_length"}
common_inputs["speaker_embeddings"] = {} # No dynamic shape here.
common_inputs["encoder_hidden_states"] = {1: "encoder_sequence_length"}
common_inputs["encoder_outputs"] = {1: "encoder_sequence_length"}
common_inputs["encoder_attention_mask"] = {1: "encoder_sequence_length"}

if self.variant == "with-past" and self.use_past_in_inputs:
Expand All @@ -1229,12 +1255,12 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
def outputs(self) -> Dict[str, Dict[int, str]]:
common_outputs = {}
if self._behavior is ConfigBehavior.ENCODER:
common_outputs["encoder_hidden_states"] = {1: "encoder_sequence_length"}
common_outputs["encoder_outputs"] = {1: "encoder_sequence_length"}
common_outputs["encoder_attention_mask"] = {1: "encoder_sequence_length"}
elif self._behavior is ConfigBehavior.DECODER:
common_outputs["output_sequence"] = {1: "decoder_sequence_length + 1"}
common_outputs["prob"] = {} # No dynamic shape here.
common_outputs["output_sequence_out"] = {1: "decoder_sequence_length + 1"}
common_outputs["spectrum"] = {} # No dynamic shape here.
common_outputs["prob"] = {} # No dynamic shape here.

if self.variant == "with-past" and self.use_past:
# When exporting decoder models with use_cache=True, both the decoder without past and with past have the KV cache as an output.
Expand All @@ -1253,6 +1279,23 @@ def patch_model_for_export(
) -> "ModelPatcher":
return SpeechT5ModelPatcher(self, model, model_kwargs=model_kwargs)

@property
def torch_to_onnx_input_map(self) -> Dict[str, str]:
    """Return the input-name remapping for this config.

    Only `encoder_outputs` -> `encoder_hidden_states` is remapped.
    NOTE(review): going by the property name, keys are presumably torch forward argument
    names and values the ONNX input names -- confirm against the base class.
    Remappings for `decoder_input_ids` and `attention_mask` were considered but are
    intentionally not enabled.
    """
    input_name_map = {}
    input_name_map["encoder_outputs"] = "encoder_hidden_states"
    return input_name_map

def overwrite_shape_and_generate_input(
    self, dummy_input_gen: "DummyInputGenerator", input_name: str, framework: str, input_shapes: Dict
):
    """Generate the dummy input `input_name` with the generator's batch size forced to 1.

    Args:
        dummy_input_gen: Generator whose `batch_size` attribute is overwritten with 1 and
            whose `generate` method produces the dummy input.
        input_name: Name of the input to generate.
        framework: Framework identifier forwarded to `generate`.
        input_shapes: Unused here; the batch size is always pinned to 1.

    Returns:
        Whatever `dummy_input_gen.generate` returns for `input_name`.
    """
    # The generator is mutated in place: any previously configured batch size is discarded.
    # NOTE(review): presumably because batched generation is not supported for this model -- confirm.
    dummy_input_gen.batch_size = 1

    dtype_kwargs = {"int_dtype": self.int_dtype, "float_dtype": self.float_dtype}
    return dummy_input_gen.generate(input_name, framework=framework, **dtype_kwargs)


class Speech2TextDummyAudioInputGenerator(DummyAudioInputGenerator):
def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
Expand Down
22 changes: 15 additions & 7 deletions optimum/exporters/onnx/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,13 +364,16 @@ def __init__(
):
super().__init__(config, model, model_kwargs)

model.vocoder = model_kwargs["vocoder_model"]

def patched_forward(
input_ids=None,
speaker_embeddings=None,
encoder_outputs=None,
past_key_values=None,
output_sequence=None,
spectrogram=None,
encoder_attention_mask=None,
):
use_cache = self.real_config.use_past and self.real_config.variant == "with-past"
if self.real_config._behavior == "encoder":
Expand All @@ -387,11 +390,14 @@ def patched_forward(
encoder_out[0].shape[1], encoder_attention_mask
)

# TODO: that is wrong?
return {"encoder_out": encoder_out, "encoder_attention_mask": encoder_attention_mask}
return {
"encoder_outputs": encoder_out.last_hidden_state,
"encoder_attention_mask": encoder_attention_mask,
}

elif self.real_config._behavior == "decoder" and self.real_config.use_past_in_inputs:
encoder_hidden_states = encoder_outputs.last_hidden_state
elif self.real_config._behavior == "decoder":
# TODO: and self.real_config.use_past_in_inputs
encoder_hidden_states = encoder_outputs[0]

decoder_hidden_states = model.speecht5.decoder.prenet(output_sequence, speaker_embeddings)

Expand Down Expand Up @@ -426,9 +432,9 @@ def patched_forward(
prob = torch.sigmoid(model.speech_decoder_postnet.prob_out(last_decoder_output))

return {
"output_sequence_out": output_sequence,
"spectrum": spectrum,
"prob": prob,
"output_sequence": output_sequence,
"spectrum": spectrum
# TODO: PKV here
}
elif self.real_config.is_postnet_and_vocoder:
Expand All @@ -438,8 +444,10 @@ def patched_forward(
spectrogram = model.speech_decoder_postnet.postnet(spectrogram)
spectrogram = spectrogram.squeeze(0)

waveform = model_kwargs["vocoder"](spectrogram)
waveform = model.vocoder(spectrogram)

return {"waveform": waveform}
else:
raise ValueError("Should not happen")

self.patched_forward = patched_forward
19 changes: 16 additions & 3 deletions optimum/exporters/onnx/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,9 @@ def get_speecht5_models_for_export(
model: Union["PreTrainedModel", "TFPreTrainedModel"], config: "OnnxConfig", model_kwargs: Optional[Dict]
):
if model_kwargs is None or "vocoder" not in model_kwargs:
raise ValueError("The ONNX export of SpeechT5 requires the model_kwargs `vocoder` to be set.")
raise ValueError(
'The ONNX export of SpeechT5 requires a vocoder. Please pass `--model-kwargs \'{"vocoder": "vocoder_model_name_or_path"}\'` from the command line, or `model_kwargs={"vocoder": "vocoder_model_name_or_path"}` if calling main_export.'
)

models_for_export = {}

Expand All @@ -406,7 +408,8 @@ def get_speecht5_models_for_export(
if config.variant == "with-past":
models_for_export["decoder_with_past_model"] = model

vocoder = SpeechT5HifiGan.from_pretrained(model_kwargs["vocoder"])
# TODO: more flexibility in the vocoder class?
vocoder = SpeechT5HifiGan.from_pretrained(model_kwargs["vocoder"]).eval()
model_kwargs["vocoder_model"] = vocoder

models_for_export["decoder_postnet_and_vocoder"] = model
Expand All @@ -425,7 +428,17 @@ def get_speecht5_models_for_export(
decoder_onnx_config_with_past,
)

postnet_and_vocoder_onnx_config = config.__class__(..., is_vocoder=True)
postnet_and_vocoder_onnx_config = config.__class__(
config._config,
task=config.task,
int_dtype=config.int_dtype,
float_dtype=config.float_dtype,
use_past=use_past,
use_past_in_inputs=False, # Irrelevant here.
behavior=config._behavior, # Irrelevant here.
preprocessors=config._preprocessors,
is_postnet_and_vocoder=True,
)
models_for_export["decoder_postnet_and_vocoder"] = (
models_for_export["decoder_postnet_and_vocoder"],
postnet_and_vocoder_onnx_config,
Expand Down
1 change: 1 addition & 0 deletions optimum/utils/input_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,7 @@ class DummyTextInputGenerator(DummyInputGenerator):
SUPPORTED_INPUT_NAMES = (
"input_ids",
"attention_mask",
"encoder_attention_mask",
"token_type_ids",
"position_ids",
)
Expand Down

0 comments on commit 54d3bc7

Please sign in to comment.