working with-past version
fxmarty committed Sep 21, 2023
1 parent 54d3bc7 commit b107b2d
Showing 4 changed files with 86 additions and 47 deletions.
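The examples interspersed below are illustrative sketches, not part of the diff. For instance, assuming the usual Optimum export entry point and the public SpeechT5 checkpoints (names not taken from this commit), the with-past export this commit fixes could be driven along these lines:

# Hedged sketch: argument names and checkpoints are assumptions, not part of this diff.
from optimum.exporters.onnx import main_export

main_export(
    "microsoft/speecht5_tts",
    output="speecht5_onnx",
    task="text-to-audio",
    model_kwargs={"vocoder": "microsoft/speecht5_hifigan"},
)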
58 changes: 29 additions & 29 deletions optimum/exporters/onnx/convert.py
@@ -38,6 +38,7 @@
)
from ..error_utils import AtolError, MinimumVersionError, OutputMatchError, ShapeError
from .base import OnnxConfig
from .model_configs import SpeechT5OnnxConfig
from .utils import PickableInferenceSession, recursive_to_device


@@ -142,15 +143,13 @@ def validate_models_outputs(
if use_subprocess:
logger.info("Validating models in subprocesses...")
exceptions = [] # run all validations before raising
onnx_paths = []
for i, model_name in enumerate(models_and_onnx_configs.keys()):
submodel, sub_onnx_config = models_and_onnx_configs[model_name]
onnx_model_path = (
output_dir.joinpath(onnx_files_subpaths[i])
if onnx_files_subpaths is not None
else output_dir.joinpath(model_name + ".onnx")
)
onnx_paths.append(onnx_model_path)
try:
# Model validation is done in subprocesses, as ONNX Runtime has the bad habit of
# not releasing memory once an InferenceSession is initialized.
@@ -168,12 +167,12 @@
model_kwargs=model_kwargs,
)
except Exception as e:
exceptions.append(e)
exceptions.append((onnx_model_path, e))

if len(exceptions) != 0:
for i, exception in enumerate(exceptions[:-1]):
logger.error(f"Validation {i} for the model {onnx_paths[i].as_posix()} raised: {exception}")
raise exceptions[-1]
logger.error(f"Validation for the model {exception[0].as_posix()} raised: {exception[1]}")
raise exceptions[-1][1]
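The hunk above now pairs each failed validation with the ONNX file that produced it, instead of collecting bare exceptions. A minimal standalone sketch of the pattern (validate_one is a hypothetical stand-in for the real validation call):

# All validations run; all but the last failure are logged, the last is re-raised.
from pathlib import Path

def validate_one(path: Path) -> None:  # hypothetical stand-in for _run_validation
    raise ValueError(f"outputs mismatch for {path.name}")

exceptions = []
for path in [Path("encoder_model.onnx"), Path("decoder_model.onnx")]:
    try:
        validate_one(path)
    except Exception as e:
        exceptions.append((path, e))

if len(exceptions) != 0:
    for path, exc in exceptions[:-1]:
        print(f"Validation for the model {path.as_posix()} raised: {exc}")
    raise exceptions[-1][1]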


def validate_model_outputs(
@@ -423,9 +422,11 @@ def _run_validation(

if value_failures:
msg = "\n".join(f"- {t[0]}: max diff = {t[1]}" for t in value_failures)
raise AtolError(
f"The maximum absolute difference between the output of the reference model and the ONNX exported model is not within the set tolerance {atol}:\n{msg}"
)
atol_msg = f"The maximum absolute difference between the output of the reference model and the ONNX exported model is not within the set tolerance {atol}:\n{msg}"

if isinstance(config, SpeechT5OnnxConfig):
atol_msg += "\nIMPORTANT NOTE: SpeechT5 applies dropout at inference, so the output validation of ONNX Runtime inference vs PyTorch is expected to fail. Reference: https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/speecht5/modeling_speecht5.py#L727"
raise AtolError(atol_msg)
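The SpeechT5 carve-out exists because the speech decoder prenet keeps dropout active even in eval mode (see the linked modeling code), so PyTorch itself is non-deterministic at inference and an absolute-tolerance comparison against ONNX Runtime is expected to fail. A hedged sketch demonstrating the effect (checkpoint names assumed):

# Two identical runs differ because the decoder prenet applies dropout at inference.
import torch
from transformers import SpeechT5ForTextToSpeech, SpeechT5Processor

processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts").eval()

inputs = processor(text="Hello world", return_tensors="pt")
speaker_embeddings = torch.zeros((1, 512))  # placeholder x-vector

with torch.no_grad():
    spec1 = model.generate_speech(inputs["input_ids"], speaker_embeddings)
    spec2 = model.generate_speech(inputs["input_ids"], speaker_embeddings)

# Even the generated spectrogram lengths can differ between runs.
print(spec1.shape, spec2.shape)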


class ValidationProcess(mp.Process):
@@ -526,7 +527,7 @@ def export_pytorch(

with torch.no_grad():
model.config.return_dict = True
model.eval()
model = model.eval()

# Check if we need to override certain configuration item
if config.values_override is not None:
@@ -560,26 +561,25 @@ def remap(value):
if is_torch_less_than_1_11:
raise RuntimeError("The ONNX export using the PyTorch framework is only supported for v1.11+")
else:
with torch.no_grad():
with config.patch_model_for_export(model, model_kwargs=model_kwargs):
check_dummy_inputs_are_allowed(model, dummy_inputs)

inputs = config.ordered_inputs(model)
input_names = list(inputs.keys())
output_names = list(config.outputs.keys())

# Export can work with named args but the dict containing named args has to be the last element of the args
# tuple.
onnx_export(
model,
(dummy_inputs,),
f=output.as_posix(),
input_names=input_names,
output_names=output_names,
dynamic_axes=dict(chain(inputs.items(), config.outputs.items())),
do_constant_folding=True,
opset_version=opset,
)
with config.patch_model_for_export(model, model_kwargs=model_kwargs):
check_dummy_inputs_are_allowed(model, dummy_inputs)

inputs = config.ordered_inputs(model)
input_names = list(inputs.keys())
output_names = list(config.outputs.keys())

# Export can work with named args but the dict containing named args has to be the last element of the args
# tuple.
onnx_export(
model,
(dummy_inputs,),
f=output.as_posix(),
input_names=input_names,
output_names=output_names,
dynamic_axes=dict(chain(inputs.items(), config.outputs.items())),
do_constant_folding=True,
opset_version=opset,
)

# check if external data was exported
# TODO: this is quite inefficient as we load in memory if models are <2GB without external data
48 changes: 35 additions & 13 deletions optimum/exporters/onnx/model_configs.py
@@ -327,8 +327,6 @@ class T5OnnxConfig(TextSeq2SeqOnnxConfig):
num_attention_heads="num_heads",
encoder_num_layers="num_layers",
decoder_num_layers="num_decoder_layers",
key_value_dim="d_kv",
allow_new=True,
)

def generate_dummy_inputs_for_validation(
@@ -1183,16 +1181,22 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int
class SpeechT5OnnxConfig(OnnxSeq2SeqConfigWithPast):
# TODO: Transformers batched generation for SpeechT5 is BROKEN (https://github.com/huggingface/transformers/pull/25943),
# so we won't support it for now.
NORMALIZED_CONFIG_CLASS = NormalizedConfig
NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args(decoder_num_layers="decoder_layers")
NORMALIZED_CONFIG_CLASS = NormalizedSeq2SeqConfig.with_args(
hidden_size="hidden_size",
num_attention_heads="encoder_attention_heads", # TODO: bugged in case encoder and decoder have different number of heads
encoder_num_layers="encoder_layers",
decoder_num_layers="decoder_layers",
allow_new=True,
)

DUMMY_INPUT_GENERATOR_CLASSES = (
DummyTextInputGenerator,
DummySeq2SeqDecoderTextInputGenerator,
T5DummySeq2SeqPastKeyValuesGenerator,
DummySeq2SeqPastKeyValuesGenerator,
DummySpeechT5InputGenerator,
)
DUMMY_PKV_GENERATOR_CLASS = T5DummySeq2SeqPastKeyValuesGenerator

# TODO: DO NOT CUT OUTPUT_SEQUENCE LENGTH WITH PAST!!!!!
DUMMY_PKV_GENERATOR_CLASS = DummySeq2SeqPastKeyValuesGenerator

VARIANTS = {
"with-past": "The export follows the Transformers implementation using the KV cache, with the following components exported:\n\t - encoder_model.onnx: corresponds to the encoding part in https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/speecht5/modeling_speecht5.py#L2544-L2556.\n\t - decoder_model.onnx: corresponds to the decoder part in https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/speecht5/modeling_speecht5.py#L2572-L2602.\n\t - decoder_with_past_model.onnx: same as the above, with past_key_values input (KV cache filled).\n\t - decoder_postnet_and_vocoder.onnx: Decoder speech postnet and vocoder (e.g. a SpeechT5HifiGan) to generate speech from the spectrogram, as in https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/speecht5/modeling_speecht5.py#L2605-L2614.",
@@ -1240,7 +1244,6 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
common_inputs["encoder_attention_mask"] = {1: "encoder_sequence_length"}

if self.variant == "with-past" and self.use_past_in_inputs:
# TODO: check PKV shape
self.add_past_key_values(common_inputs, direction="inputs")
elif self.is_postnet_and_vocoder:
common_inputs["spectrogram"] = {0: "n_spectrums x reduction_factor"}
@@ -1281,11 +1284,7 @@ def patch_model_for_export(

@property
def torch_to_onnx_input_map(self) -> Dict[str, str]:
return {
# "decoder_input_ids": "input_ids",
"encoder_outputs": "encoder_hidden_states",
# "attention_mask": "encoder_attention_mask",
}
return {"encoder_outputs": "encoder_hidden_states"}

def overwrite_shape_and_generate_input(
self, dummy_input_gen: "DummyInputGenerator", input_name: str, framework: str, input_shapes: Dict
@@ -1296,6 +1295,29 @@
)
return dummy_input

def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str):
if direction not in ["inputs", "outputs"]:
raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given')

if direction == "inputs":
decoder_sequence_name = "past_decoder_sequence_length"
name = "past_key_values"
else:
decoder_sequence_name = "past_decoder_sequence_length + 1"
name = "present"

for i in range(self._normalized_config.decoder_num_layers):
inputs_or_outputs[f"{name}.{i}.decoder.key"] = {2: decoder_sequence_name}
inputs_or_outputs[f"{name}.{i}.decoder.value"] = {2: decoder_sequence_name}

if (
self.is_merged is True
or (self._behavior is ConfigBehavior.DECODER and not self.use_past_in_inputs)
or direction == "inputs"
):
inputs_or_outputs[f"{name}.{i}.encoder.key"] = {2: "encoder_sequence_length_out"}
inputs_or_outputs[f"{name}.{i}.encoder.value"] = {2: "encoder_sequence_length_out"}


class Speech2TextDummyAudioInputGenerator(DummyAudioInputGenerator):
def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
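For orientation, the add_past_key_values override above marks axis 2 of each decoder self-attention cache tensor as dynamic, and emits the cross-attention (encoder) entries only when they are actually inputs or freshly computed outputs. A sketch of the mapping it builds for direction="inputs" with a hypothetical 6-layer decoder:

# Illustrative reconstruction, not captured output; axis 2 is the past sequence length.
inputs = {}
for i in range(6):
    inputs[f"past_key_values.{i}.decoder.key"] = {2: "past_decoder_sequence_length"}
    inputs[f"past_key_values.{i}.decoder.value"] = {2: "past_decoder_sequence_length"}
    # For direction="inputs", the cross-attention cache is always included:
    inputs[f"past_key_values.{i}.encoder.key"] = {2: "encoder_sequence_length_out"}
    inputs[f"past_key_values.{i}.encoder.value"] = {2: "encoder_sequence_length_out"}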
26 changes: 21 additions & 5 deletions optimum/exporters/onnx/model_patcher.py
@@ -364,7 +364,7 @@ def __init__(
):
super().__init__(config, model, model_kwargs)

model.vocoder = model_kwargs["vocoder_model"]
model.vocoder = model_kwargs["vocoder_model"].eval()

def patched_forward(
input_ids=None,
@@ -390,7 +390,7 @@ def patched_forward(
encoder_out[0].shape[1], encoder_attention_mask
)

return {
result = {
"encoder_outputs": encoder_out.last_hidden_state,
"encoder_attention_mask": encoder_attention_mask,
}
@@ -431,11 +431,11 @@ def patched_forward(
# Predict the probability that this is the stop token.
prob = torch.sigmoid(model.speech_decoder_postnet.prob_out(last_decoder_output))

return {
result = {
"output_sequence_out": output_sequence,
"spectrum": spectrum,
"prob": prob,
# TODO: PKV here
"past_key_values": past_key_values,
}
elif self.real_config.is_postnet_and_vocoder:
# NOTE: the following concatenation is expected to be handled outside of the ONNX:
@@ -446,8 +446,24 @@

waveform = model.vocoder(spectrogram)

return {"waveform": waveform}
result = {"waveform": waveform}
else:
raise ValueError("Should not happen")

# Filter out the cross-attention past key values output by the decoder using the KV cache, as they are constants.
filtered_outputs = {}
for name, value in result.items():
if name != "past_key_values":
filtered_outputs[name] = value
else:
if self.real_config._behavior == "decoder" and (
self.real_config.is_merged or not self.real_config.use_past_in_inputs
):
filtered_outputs[name] = value
elif self.real_config._behavior == "decoder" and self.real_config.use_past_in_inputs:
# The filtering happens here. The decoder with use_past_in_inputs=True corresponds to the autoregressive one.
filtered_outputs[name] = tuple([v[:2] for v in value])

return filtered_outputs

self.patched_forward = patched_forward
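To make the v[:2] filtering concrete: with the KV cache, each layer's past_key_values entry is a 4-tuple of (self-attention key, self-attention value, cross-attention key, cross-attention value). The cross-attention pair is computed once from the encoder output and never changes, so the autoregressive decoder only re-emits the self-attention pair. A standalone sketch with made-up shapes:

# Shapes are (batch, heads, sequence, head_dim) and purely illustrative.
import torch

past_key_values = tuple(
    (
        torch.rand(1, 12, 10, 64),  # self-attention key (grows each step)
        torch.rand(1, 12, 10, 64),  # self-attention value (grows each step)
        torch.rand(1, 12, 25, 64),  # cross-attention key (constant)
        torch.rand(1, 12, 25, 64),  # cross-attention value (constant)
    )
    for _ in range(6)
)

filtered = tuple(v[:2] for v in past_key_values)  # keep only the self-attention cache
print(len(filtered[0]))  # 2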
1 change: 1 addition & 0 deletions optimum/exporters/onnx/utils.py
@@ -439,6 +439,7 @@ def get_speecht5_models_for_export(
preprocessors=config._preprocessors,
is_postnet_and_vocoder=True,
)
postnet_and_vocoder_onnx_config.variant = config.variant
models_for_export["decoder_postnet_and_vocoder"] = (
models_for_export["decoder_postnet_and_vocoder"],
postnet_and_vocoder_onnx_config,
)
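After an export with the with-past variant, the four components named in the VARIANTS docstring could be inspected along these lines (a sketch; the output directory is assumed):

# Hedged sketch: list each exported component's inputs with ONNX Runtime.
import onnxruntime as ort

for name in (
    "encoder_model",
    "decoder_model",
    "decoder_with_past_model",
    "decoder_postnet_and_vocoder",
):
    session = ort.InferenceSession(f"speecht5_onnx/{name}.onnx", providers=["CPUExecutionProvider"])
    print(name, "->", [inp.name for inp in session.get_inputs()])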
