From 9ac3402019a01a158c04d427e52b7e74fc211fb9 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?F=C3=A9lix=20Marty?= <9808326+fxmarty@users.noreply.github.com>
Date: Tue, 3 Oct 2023 14:37:10 +0200
Subject: [PATCH] add ugly code

---
 optimum/exporters/onnx/__main__.py      |  7 ++++
 optimum/exporters/onnx/base.py          | 45 ++++++++++++++++---------
 optimum/exporters/onnx/config.py        | 16 ++++-----
 optimum/exporters/onnx/convert.py       | 28 ++++++++++-----
 optimum/exporters/onnx/model_configs.py | 11 ------
 optimum/exporters/onnx/model_patcher.py | 20 +++++------
 optimum/exporters/tasks.py              |  1 +
 optimum/utils/input_generators.py       | 20 ++++++++---
 8 files changed, 91 insertions(+), 57 deletions(-)

diff --git a/optimum/exporters/onnx/__main__.py b/optimum/exporters/onnx/__main__.py
index 16a18afc552..e61722fefd1 100644
--- a/optimum/exporters/onnx/__main__.py
+++ b/optimum/exporters/onnx/__main__.py
@@ -348,6 +348,11 @@ def main_export(
         device=device,
         library_name=library_name,
     )
+    import torch
+    model = model.eval()
+    with torch.no_grad():
+        for i in range(len(model.model.decoder.layers)):
+            model.model.decoder.layers[i].encoder_attn = torch.jit.script(model.model.decoder.layers[i].encoder_attn)
 
     custom_architecture = False
     is_stable_diffusion = "stable-diffusion" in task
@@ -574,10 +579,12 @@ def main_export(
            logger.warning(
                f"The ONNX export succeeded with the warning: {e}.\n The exported model was saved at: {output.as_posix()}"
            )
+            """
        except Exception as e:
            raise Exception(
                f"An error occured during validation, but the model was saved nonetheless at {output.as_posix()}. Detailed error: {e}."
            )
+            """
 
 
 def main():
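Note: the __main__.py hunk above scripts each decoder layer's cross-attention before the export runs. A standalone sketch of the same pattern, assuming a Whisper checkpoint ("openai/whisper-tiny" is an assumption; the model.model.decoder.layers[i].encoder_attn path is the transformers seq2seq layout this hunk relies on):

    import torch
    from transformers import WhisperForConditionalGeneration

    model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny").eval()
    with torch.no_grad():
        for layer in model.model.decoder.layers:
            # Scripting keeps the attention's control flow intact when the whole
            # decoder is later traced (see the convert.py hunks below).
            layer.encoder_attn = torch.jit.script(layer.encoder_attn)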
diff --git a/optimum/exporters/onnx/base.py b/optimum/exporters/onnx/base.py
index 1e2ae99955c..ab15215200f 100644
--- a/optimum/exporters/onnx/base.py
+++ b/optimum/exporters/onnx/base.py
@@ -325,7 +325,7 @@ def fix_dynamic_axes(
                 onnx_inputs[name] = value
 
         for name, value in onnx_inputs.items():
-            if value.dtype == np.float32 and dtype == "fp16":
+            if value is not None and (value.dtype == np.float32 and dtype == "fp16"):
                 onnx_inputs[name] = onnx_inputs[name].astype(np.float16)
 
         outputs = session.run(None, onnx_inputs)
@@ -579,14 +579,16 @@ def __init__(
 
     @property
     def outputs(self) -> Dict[str, Dict[int, str]]:
-        if not self.use_past_in_inputs:
-            common_outputs = super().outputs
+        #if not self.use_past_in_inputs:
+        #    common_outputs = super().outputs
+        self.use_past_in_inputs = True
+
         # In the other cases, the sequence_length axis is not dynamic, always of length 1
-        elif self.task == "feature-extraction":
+        if self.task == "feature-extraction":
             common_outputs = OrderedDict({"last_hidden_state": {0: "batch_size"}})
         else:
-            common_outputs = OrderedDict({"logits": {0: "batch_size"}})
-        if self.use_past:
+            common_outputs = OrderedDict({"logits": {}})
+        # When exporting decoder models with use_cache=True, both the decoder without past and with past have the KV cache as an output.
             self.add_past_key_values(common_outputs, direction="outputs")
         return common_outputs
 
@@ -602,7 +604,9 @@ def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
         dummy_inputs = {}
         input_names = [key for key in self.inputs.keys() if not key.startswith("past_key_values")]
 
-        if self.use_past_in_inputs and self.use_cache_branch is not False:
+
+        print("self._behavior", self._behavior)
+        if self._behavior is not ConfigBehavior.ENCODER:
             input_names.append("past_key_values")
 
         for input_name in input_names:
@@ -821,7 +825,16 @@ def with_behavior(
 
     @property
     def outputs(self) -> Dict[str, Dict[int, str]]:
-        common_outputs = super(OnnxConfigWithPast, self).outputs
+        # In the other cases, the sequence_length axis is not dynamic, always of length 1
+        if self.task == "feature-extraction":
+            common_outputs = OrderedDict({"last_hidden_state": {}})
+        else:
+            common_outputs = OrderedDict({"logits": {}})
+
+        # When exporting decoder models with use_cache=True, both the decoder without past and with past have the KV cache as an output.
+        self.add_past_key_values(common_outputs, direction="outputs")
+
+        """
         # Renaming the outputs axes properly.
         for name, axes_names in common_outputs.items():
             if self._behavior is ConfigBehavior.ENCODER or "encoder" in name:
@@ -840,10 +853,12 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
                     else:
                         new_axes_names[axis_idx] = axis_name
             common_outputs[name] = new_axes_names
+
         if self.use_past:
             # When exporting decoder models with use_cache=True, both the decoder without past and with past have the KV cache as an output.
             self.add_past_key_values(common_outputs, direction="outputs")
+        """
 
         return common_outputs
 
@@ -859,19 +874,17 @@ def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], dire
             name = "present"
 
         for i in range(self._normalized_config.decoder_num_layers):
-            inputs_or_outputs[f"{name}.{i}.decoder.key"] = {0: "batch_size", 2: decoder_sequence_name}
-            inputs_or_outputs[f"{name}.{i}.decoder.value"] = {0: "batch_size", 2: decoder_sequence_name}
-
+            inputs_or_outputs[f"{name}.{i}.decoder.key"] = {}
+            inputs_or_outputs[f"{name}.{i}.decoder.value"] = {}
+
             if (
-                self.is_merged is True
-                or (self._behavior is ConfigBehavior.DECODER and not self.use_past_in_inputs)
-                or direction == "inputs"
+                self._behavior is ConfigBehavior.DECODER
             ):
                 # TODO: we only need to call it encoder_sequence_length_out in the merge case - but at torch.onnx.export()
                 # time we have currently no case to check whether we will merge at a later step or not (self.is_merged is
                 # not yet set at this time)
-                inputs_or_outputs[f"{name}.{i}.encoder.key"] = {0: "batch_size", 2: "encoder_sequence_length_out"}
-                inputs_or_outputs[f"{name}.{i}.encoder.value"] = {0: "batch_size", 2: "encoder_sequence_length_out"}
+                inputs_or_outputs[f"{name}.{i}.encoder.key"] = {}
+                inputs_or_outputs[f"{name}.{i}.encoder.value"] = {}
 
     def flatten_past_key_values(self, flattened_output, name, idx, t):
         if len(t) not in [2, 4]:
diff --git a/optimum/exporters/onnx/config.py b/optimum/exporters/onnx/config.py
index 9259ad853da..e1c216dc1cd 100644
--- a/optimum/exporters/onnx/config.py
+++ b/optimum/exporters/onnx/config.py
@@ -272,6 +272,7 @@ class AudioToTextOnnxConfig(OnnxSeq2SeqConfigWithPast):
         DummyAudioInputGenerator,
         DummySeq2SeqDecoderTextInputGenerator,
         DummySeq2SeqPastKeyValuesGenerator,
+        DummyTextInputGenerator,
     )
 
     @property
@@ -279,19 +280,18 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
         common_inputs = {}
 
         if self._behavior is not ConfigBehavior.DECODER:
-            common_inputs["input_features"] = {0: "batch_size", 1: "feature_size", 2: "encoder_sequence_length"}
+            common_inputs["input_features"] = {}
 
         if self._behavior is not ConfigBehavior.ENCODER:
-            if self.use_past_in_inputs:
-                common_inputs["decoder_input_ids"] = {0: "batch_size"}
-            else:
-                common_inputs["decoder_input_ids"] = {0: "batch_size", 1: "decoder_sequence_length"}
+            common_inputs["decoder_input_ids"] = {}
 
-            if self.use_past_in_inputs:
-                self.add_past_key_values(common_inputs, direction="inputs")
+            self.add_past_key_values(common_inputs, direction="inputs")
+
+            common_inputs["decoder_attention_mask"] = {}
+            common_inputs["position_ids"] = {}
 
         if self._behavior is ConfigBehavior.DECODER:
-            common_inputs["encoder_outputs"] = {0: "batch_size", 1: "encoder_sequence_length"}
+            common_inputs["encoder_outputs"] = {}
 
         return common_inputs
 
diff --git a/optimum/exporters/onnx/convert.py b/optimum/exporters/onnx/convert.py
index f637da07804..4704e9f9a2b 100644
--- a/optimum/exporters/onnx/convert.py
+++ b/optimum/exporters/onnx/convert.py
@@ -528,13 +528,6 @@ def export_pytorch(
         model.config.return_dict = True
         model.eval()
 
-        # Check if we need to override certain configuration item
-        if config.values_override is not None:
-            logger.info(f"Overriding {len(config.values_override)} configuration item(s)")
-            for override_config_key, override_config_value in config.values_override.items():
-                logger.info(f"\t- {override_config_key} -> {override_config_value}")
-                setattr(model.config, override_config_key, override_config_value)
-
         if input_shapes is None:
             input_shapes = {}  # will use the defaults from DEFAULT_DUMMY_SHAPES
 
@@ -555,6 +548,7 @@ def remap(value):
 
         dummy_inputs = config.rename_ambiguous_inputs(dummy_inputs)
 
+        print("---------------- EXPORT")
         # PyTorch deprecated the `enable_onnx_checker` and `use_external_data_format` arguments in v1.11,
         # so we check the torch version for backwards compatibility
         if is_torch_less_than_1_11:
@@ -567,11 +561,28 @@ def remap(value):
             input_names = list(inputs.keys())
             output_names = list(config.outputs.keys())
 
+            print("input_names", input_names)
+            print("output_names", output_names)
+            print("dummy_inputs keys", dummy_inputs.keys())
+
+            for name, inp in dummy_inputs.items():
+                if isinstance(inp, torch.Tensor):
+                    print(name, inp.shape)
+                else:
+                    print(name, type(inp))
+
+            if config._behavior == "decoder":
+                dummy_inputs = (dummy_inputs["decoder_input_ids"], dummy_inputs["decoder_attention_mask"], dummy_inputs["encoder_outputs"][0], dummy_inputs["past_key_values"], dummy_inputs["position_ids"])
+
+                model = torch.jit.trace(model, dummy_inputs)
+            else:
+                dummy_inputs = (dummy_inputs,)
+
             # Export can work with named args but the dict containing named args has to be the last element of the args
             # tuple.
             onnx_export(
                 model,
-                (dummy_inputs,),
+                dummy_inputs,
                 f=output.as_posix(),
                 input_names=input_names,
                 output_names=output_names,
@@ -879,6 +890,7 @@ def export(
             "You either provided a PyTorch model with only TensorFlow installed, or a TensorFlow model with only PyTorch installed."
         )
 
+    print("------ FIXING")
    if not disable_dynamic_axes_fix:
        config.fix_dynamic_axes(output, device=device, input_shapes=input_shapes, dtype=dtype)
    return export_output
diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py
index 401d995fdc7..b7a68a6bf9d 100644
--- a/optimum/exporters/onnx/model_configs.py
+++ b/optimum/exporters/onnx/model_configs.py
@@ -1130,19 +1130,8 @@ class WhisperOnnxConfig(AudioToTextOnnxConfig):
     @property
     def inputs(self) -> Dict[str, Dict[int, str]]:
         common_inputs = super().inputs
-        if self._behavior is ConfigBehavior.DECODER and self.use_past_in_inputs is False:
-            common_inputs["encoder_outputs"][1] = f"{common_inputs['encoder_outputs'][1]} / 2"
         return common_inputs
 
-    @property
-    def outputs(self) -> Dict[str, Dict[int, str]]:
-        common_outputs = super().outputs
-        if self._behavior is ConfigBehavior.ENCODER:
-            # For Whisper, we need to name the second axis as encoder_sequence_length / 2 as the axis name is used for
-            # dummy input generation
-            common_outputs["last_hidden_state"][1] = f"{common_outputs['last_hidden_state'][1]} / 2"
-        return common_outputs
-
 
 class Speech2TextDummyAudioInputGenerator(DummyAudioInputGenerator):
     def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
diff --git a/optimum/exporters/onnx/model_patcher.py b/optimum/exporters/onnx/model_patcher.py
index e6b50b6dc08..634b7d9c4ad 100644
--- a/optimum/exporters/onnx/model_patcher.py
+++ b/optimum/exporters/onnx/model_patcher.py
@@ -193,13 +193,15 @@ def patched_forward(*args, **kwargs):
 
             outputs = self.orig_forward(*args, **kwargs)
 
+            """
             # Filter out cross attention past key values output from the decoder using KV cache, as they are constants.
             filterd_outputs = {}
             for name, value in outputs.items():
+                # filterd_outputs[name] = value
                 onnx_output_name = config.torch_to_onnx_output_map.get(name, name)
                 if (
                     onnx_output_name in config.outputs
-                    or (allow_past_in_outputs and name.startswith("past_key_values"))
+                    or name.startswith("past_key_values")
                     or any(key.startswith(onnx_output_name) for key in config.outputs.keys())
                 ):
                     if name != "past_key_values":
@@ -209,16 +211,14 @@ def patched_forward(*args, **kwargs):
                     else:
                         filterd_outputs[name] = value
                 else:
-                    if self.real_config._behavior == "monolith" or (
-                        self.real_config._behavior == "decoder"
-                        and (self.real_config.is_merged or not self.real_config.use_past_in_inputs)
-                    ):
-                        filterd_outputs[name] = value
-                    elif self.real_config._behavior == "decoder" and self.real_config.use_past_in_inputs:
-                        # The filtering happens here. The decoder with use_past_in_inputs=True corresponds to the autoregressive one.
-                        filterd_outputs[name] = tuple([v[:2] for v in value])
+                    # The filtering happens here. The decoder with use_past_in_inputs=True corresponds to the autoregressive one.
+                    # filterd_outputs[name] = tuple([v[:2] for v in value])
+                    filterd_outputs[name] = value
+            """
 
-            return filterd_outputs
+            #print("filterd_outputs", filterd_outputs.keys())
+
+            return outputs
 
         self.patched_forward = patched_forward
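Note: the filtering disabled in model_patcher.py above used to drop the cross-attention key/value from each layer's past tuple, since they are constant across decoding steps. What the removed tuple([v[:2] for v in value]) did, on stand-in values:

    past_key_values = [("dec_k", "dec_v", "enc_k", "enc_v")] * 2  # stand-ins for per-layer tensors
    filtered = tuple(v[:2] for v in past_key_values)
    # (("dec_k", "dec_v"), ("dec_k", "dec_v")), cross-attention entries dropped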
""" + return "transformers" # working offline if library_name is not None: return library_name diff --git a/optimum/utils/input_generators.py b/optimum/utils/input_generators.py index 227c12315d9..e465e1165c2 100644 --- a/optimum/utils/input_generators.py +++ b/optimum/utils/input_generators.py @@ -360,7 +360,13 @@ def __init__( def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): min_value = 0 max_value = 2 if input_name != "input_ids" else self.vocab_size - shape = [self.batch_size, self.sequence_length] + + # TODO: fix + if input_name == "decoder_attention_mask": + shape = [self.batch_size, 128] + else: + shape = [self.batch_size, self.sequence_length] + if self.task == "multiple-choice": shape = [self.batch_size, self.num_choices, self.sequence_length] return self.random_int_tensor(shape, max_value, min_value=min_value, framework=framework, dtype=int_dtype) @@ -414,7 +420,7 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int if input_name in ["encoder_outputs", "encoder_hidden_states"]: return ( self.random_float_tensor( - shape=[self.batch_size, self.sequence_length, self.hidden_size], + shape=[self.batch_size, 1500, self.hidden_size], min_value=0, max_value=1, framework=framework, @@ -492,6 +498,8 @@ def __init__( **kwargs, ): self.normalized_config = normalized_config + self.batch_size = 1 + """ if random_batch_size_range: low, high = random_batch_size_range self.batch_size = random.randint(low, high) @@ -505,6 +513,10 @@ def __init__( self.encoder_sequence_length = ( self.sequence_length if encoder_sequence_length is None else encoder_sequence_length ) + """ + self.encoder_sequence_length = 1500 + self.sequence_length = 128 + def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): encoder_shape = ( @@ -519,7 +531,7 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int self.sequence_length, self.normalized_config.hidden_size // self.normalized_config.decoder_num_attention_heads, ) - return [ + return tuple( ( self.random_float_tensor(decoder_shape, framework=framework, dtype=float_dtype), self.random_float_tensor(decoder_shape, framework=framework, dtype=float_dtype), @@ -527,7 +539,7 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int self.random_float_tensor(encoder_shape, framework=framework, dtype=float_dtype), ) for _ in range(self.normalized_config.decoder_num_layers) - ] + ) # TODO: should it just be merged to DummyTextInputGenerator?