From b107b2d74c52d1575127b8d1d3ed65ef0334bf40 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?F=C3=A9lix=20Marty?= <9808326+fxmarty@users.noreply.github.com>
Date: Thu, 21 Sep 2023 17:36:28 +0200
Subject: [PATCH] working with-past version

---
 optimum/exporters/onnx/convert.py       | 58 ++++++++++++-------------
 optimum/exporters/onnx/model_configs.py | 47 +++++++++++++------
 optimum/exporters/onnx/model_patcher.py | 26 ++++++++---
 optimum/exporters/onnx/utils.py         |  1 +
 4 files changed, 85 insertions(+), 47 deletions(-)

diff --git a/optimum/exporters/onnx/convert.py b/optimum/exporters/onnx/convert.py
index 821a39cb06b..0b00667e6c8 100644
--- a/optimum/exporters/onnx/convert.py
+++ b/optimum/exporters/onnx/convert.py
@@ -38,6 +38,7 @@
 )
 from ..error_utils import AtolError, MinimumVersionError, OutputMatchError, ShapeError
 from .base import OnnxConfig
+from .model_configs import SpeechT5OnnxConfig
 from .utils import PickableInferenceSession, recursive_to_device
 
 
@@ -142,7 +143,6 @@ def validate_models_outputs(
     if use_subprocess:
         logger.info("Validating models in subprocesses...")
     exceptions = []  # run all validations before raising
-    onnx_paths = []
     for i, model_name in enumerate(models_and_onnx_configs.keys()):
         submodel, sub_onnx_config = models_and_onnx_configs[model_name]
         onnx_model_path = (
@@ -150,7 +150,6 @@ def validate_models_outputs(
             if onnx_files_subpaths is not None
             else output_dir.joinpath(model_name + ".onnx")
         )
-        onnx_paths.append(onnx_model_path)
         try:
             # Model validation is done in subprocesses, as ONNX Runtime has the bad habit of
             # not releasing memory once an InferenceSession is initialized.
@@ -168,12 +167,12 @@ def validate_models_outputs(
                 model_kwargs=model_kwargs,
             )
         except Exception as e:
-            exceptions.append(e)
+            exceptions.append((onnx_model_path, e))
 
     if len(exceptions) != 0:
         for i, exception in enumerate(exceptions[:-1]):
-            logger.error(f"Validation {i} for the model {onnx_paths[i].as_posix()} raised: {exception}")
-        raise exceptions[-1]
+            logger.error(f"Validation for the model {exception[0].as_posix()} raised: {exception[1]}")
+        raise exceptions[-1][1]
 
 
 def validate_model_outputs(
@@ -423,9 +422,11 @@ def _run_validation(
 
     if value_failures:
         msg = "\n".join(f"- {t[0]}: max diff = {t[1]}" for t in value_failures)
-        raise AtolError(
-            f"The maximum absolute difference between the output of the reference model and the ONNX exported model is not within the set tolerance {atol}:\n{msg}"
-        )
+        atol_msg = f"The maximum absolute difference between the output of the reference model and the ONNX exported model is not within the set tolerance {atol}:\n{msg}"
+
+        if isinstance(config, SpeechT5OnnxConfig):
+            atol_msg += "\nIMPORTANT NOTE: SpeechT5 uses a dropout at inference and the output validation of ONNX Runtime inference vs PyTorch is expected to fail. Reference: https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/speecht5/modeling_speecht5.py#L727"
+        raise AtolError(atol_msg)
 
 
 class ValidationProcess(mp.Process):
@@ -526,7 +527,7 @@ def export_pytorch(
 
     with torch.no_grad():
         model.config.return_dict = True
-        model.eval()
+        model = model.eval()
 
         # Check if we need to override certain configuration item
         if config.values_override is not None:
@@ -560,26 +561,25 @@ def remap(value):
         if is_torch_less_than_1_11:
             raise RuntimeError("The ONNX export using the PyTorch framework is only supported for v1.11+")
         else:
-            with torch.no_grad():
-                with config.patch_model_for_export(model, model_kwargs=model_kwargs):
-                    check_dummy_inputs_are_allowed(model, dummy_inputs)
-
-                    inputs = config.ordered_inputs(model)
-                    input_names = list(inputs.keys())
-                    output_names = list(config.outputs.keys())
-
-                    # Export can work with named args but the dict containing named args has to be the last element of the args
-                    # tuple.
-                    onnx_export(
-                        model,
-                        (dummy_inputs,),
-                        f=output.as_posix(),
-                        input_names=input_names,
-                        output_names=output_names,
-                        dynamic_axes=dict(chain(inputs.items(), config.outputs.items())),
-                        do_constant_folding=True,
-                        opset_version=opset,
-                    )
+            with config.patch_model_for_export(model, model_kwargs=model_kwargs):
+                check_dummy_inputs_are_allowed(model, dummy_inputs)
+
+                inputs = config.ordered_inputs(model)
+                input_names = list(inputs.keys())
+                output_names = list(config.outputs.keys())
+
+                # Export can work with named args but the dict containing named args has to be the last element of the args
+                # tuple.
+                onnx_export(
+                    model,
+                    (dummy_inputs,),
+                    f=output.as_posix(),
+                    input_names=input_names,
+                    output_names=output_names,
+                    dynamic_axes=dict(chain(inputs.items(), config.outputs.items())),
+                    do_constant_folding=True,
+                    opset_version=opset,
+                )
 
         # check if external data was exported
         # TODO: this is quite inefficient as we load in memory if models are <2GB without external data
diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py
index bbffcec529f..d3ff9944b92 100644
--- a/optimum/exporters/onnx/model_configs.py
+++ b/optimum/exporters/onnx/model_configs.py
@@ -327,8 +327,6 @@ class T5OnnxConfig(TextSeq2SeqOnnxConfig):
         num_attention_heads="num_heads",
         encoder_num_layers="num_layers",
         decoder_num_layers="num_decoder_layers",
-        key_value_dim="d_kv",
-        allow_new=True,
     )
 
     def generate_dummy_inputs_for_validation(
@@ -1183,16 +1181,21 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int
 class SpeechT5OnnxConfig(OnnxSeq2SeqConfigWithPast):
     # TODO: Transformers batched generation for SpeechT5 is BROKEN (https://github.com/huggingface/transformers/pull/25943),
     # so we won't support it for now.
-    NORMALIZED_CONFIG_CLASS = NormalizedConfig
+    NORMALIZED_CONFIG_CLASS = NormalizedSeq2SeqConfig.with_args(
+        hidden_size="hidden_size",
+        num_attention_heads="encoder_attention_heads",  # TODO: bugged in case the encoder and decoder have a different number of heads
+        encoder_num_layers="encoder_layers",
+        decoder_num_layers="decoder_layers",
+        allow_new=True,
+    )
+
     DUMMY_INPUT_GENERATOR_CLASSES = (
         DummyTextInputGenerator,
         DummySeq2SeqDecoderTextInputGenerator,
-        T5DummySeq2SeqPastKeyValuesGenerator,
+        DummySeq2SeqPastKeyValuesGenerator,
         DummySpeechT5InputGenerator,
     )
-    DUMMY_PKV_GENERATOR_CLASS = T5DummySeq2SeqPastKeyValuesGenerator
-
-    # TODO: DO NOT CUT OUTPUT_SEQUENCE LENGTH WITH PAST!!!!!
+    DUMMY_PKV_GENERATOR_CLASS = DummySeq2SeqPastKeyValuesGenerator
 
     VARIANTS = {
         "with-past": "The export follows the Transformers implementation using the KV cache, with the following components exported:\n\t - encoder_model.onnx: corresponds to the encoding part in https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/speecht5/modeling_speecht5.py#L2544-L2556.\n\t - decoder_model.onnx: corresponds to the decoder part in https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/speecht5/modeling_speecht5.py#L2572-L2602.\n\t - decoder_with_past_model.onnx: same as the above, with past_key_values input (KV cache filled).\n\t - decoder_postnet_and_vocoder.onnx: Decoder speech postnet and vocoder (e.g. a SpeechT5HifiGan) to generate speech from the spectrogram, as in https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/speecht5/modeling_speecht5.py#L2605-L2614.",
@@ -1240,7 +1243,6 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
         common_inputs["encoder_attention_mask"] = {1: "encoder_sequence_length"}
 
         if self.variant == "with-past" and self.use_past_in_inputs:
-            # TODO: check PKV shape
             self.add_past_key_values(common_inputs, direction="inputs")
         elif self.is_postnet_and_vocoder:
             common_inputs["spectrogram"] = {0: "n_spectrums x reduction_factor"}
@@ -1281,11 +1283,7 @@ def patch_model_for_export(
 
     @property
     def torch_to_onnx_input_map(self) -> Dict[str, str]:
-        return {
-            # "decoder_input_ids": "input_ids",
-            "encoder_outputs": "encoder_hidden_states",
-            # "attention_mask": "encoder_attention_mask",
-        }
+        return {"encoder_outputs": "encoder_hidden_states"}
 
     def overwrite_shape_and_generate_input(
         self, dummy_input_gen: "DummyInputGenerator", input_name: str, framework: str, input_shapes: Dict
@@ -1296,6 +1294,29 @@ def overwrite_shape_and_generate_input(
         )
         return dummy_input
 
+    def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str):
+        if direction not in ["inputs", "outputs"]:
+            raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given')
+
+        if direction == "inputs":
+            decoder_sequence_name = "past_decoder_sequence_length"
+            name = "past_key_values"
+        else:
+            decoder_sequence_name = "past_decoder_sequence_length + 1"
+            name = "present"
+
+        for i in range(self._normalized_config.decoder_num_layers):
+            inputs_or_outputs[f"{name}.{i}.decoder.key"] = {2: decoder_sequence_name}
+            inputs_or_outputs[f"{name}.{i}.decoder.value"] = {2: decoder_sequence_name}
+
+            if (
+                self.is_merged is True
+                or (self._behavior is ConfigBehavior.DECODER and not self.use_past_in_inputs)
+                or direction == "inputs"
+            ):
+                inputs_or_outputs[f"{name}.{i}.encoder.key"] = {2: "encoder_sequence_length_out"}
inputs_or_outputs[f"{name}.{i}.encoder.value"] = {2: "encoder_sequence_length_out"} + class Speech2TextDummyAudioInputGenerator(DummyAudioInputGenerator): def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): diff --git a/optimum/exporters/onnx/model_patcher.py b/optimum/exporters/onnx/model_patcher.py index 9679159dd7a..63edefb63cc 100644 --- a/optimum/exporters/onnx/model_patcher.py +++ b/optimum/exporters/onnx/model_patcher.py @@ -364,7 +364,7 @@ def __init__( ): super().__init__(config, model, model_kwargs) - model.vocoder = model_kwargs["vocoder_model"] + model.vocoder = model_kwargs["vocoder_model"].eval() def patched_forward( input_ids=None, @@ -390,7 +390,7 @@ def patched_forward( encoder_out[0].shape[1], encoder_attention_mask ) - return { + result = { "encoder_outputs": encoder_out.last_hidden_state, "encoder_attention_mask": encoder_attention_mask, } @@ -431,11 +431,11 @@ def patched_forward( # Predict the probability that this is the stop token. prob = torch.sigmoid(model.speech_decoder_postnet.prob_out(last_decoder_output)) - return { + result = { "output_sequence_out": output_sequence, "spectrum": spectrum, "prob": prob, - # TODO: PKV here + "past_key_values": past_key_values, } elif self.real_config.is_postnet_and_vocoder: # NOTE: the following concatenation is expected to be handled outside of the ONNX: @@ -446,8 +446,24 @@ def patched_forward( waveform = model.vocoder(spectrogram) - return {"waveform": waveform} + result = {"waveform": waveform} else: raise ValueError("Should not happen") + # Filter out cross attention past key values output from the decoder using KV cache, as they are constants. + filterd_outputs = {} + for name, value in result.items(): + if name != "past_key_values": + filterd_outputs[name] = value + else: + if self.real_config._behavior == "decoder" and ( + self.real_config.is_merged or not self.real_config.use_past_in_inputs + ): + filterd_outputs[name] = value + elif self.real_config._behavior == "decoder" and self.real_config.use_past_in_inputs: + # The filtering happens here. The decoder with use_past_in_inputs=True corresponds to the autoregressive one. + filterd_outputs[name] = tuple([v[:2] for v in value]) + + return filterd_outputs + self.patched_forward = patched_forward diff --git a/optimum/exporters/onnx/utils.py b/optimum/exporters/onnx/utils.py index aa743db3b44..25c50a36dcc 100644 --- a/optimum/exporters/onnx/utils.py +++ b/optimum/exporters/onnx/utils.py @@ -439,6 +439,7 @@ def get_speecht5_models_for_export( preprocessors=config._preprocessors, is_postnet_and_vocoder=True, ) + postnet_and_vocoder_onnx_config.variant = config.variant models_for_export["decoder_postnet_and_vocoder"] = ( models_for_export["decoder_postnet_and_vocoder"], postnet_and_vocoder_onnx_config,