From 9ac3402019a01a158c04d427e52b7e74fc211fb9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?F=C3=A9lix=20Marty?= <9808326+fxmarty@users.noreply.github.com> Date: Tue, 3 Oct 2023 14:37:10 +0200 Subject: [PATCH 1/5] add ugly code --- optimum/exporters/onnx/__main__.py | 7 ++++ optimum/exporters/onnx/base.py | 45 ++++++++++++++++--------- optimum/exporters/onnx/config.py | 16 ++++----- optimum/exporters/onnx/convert.py | 28 ++++++++++----- optimum/exporters/onnx/model_configs.py | 11 ------ optimum/exporters/onnx/model_patcher.py | 20 +++++------ optimum/exporters/tasks.py | 1 + optimum/utils/input_generators.py | 20 ++++++++--- 8 files changed, 91 insertions(+), 57 deletions(-) diff --git a/optimum/exporters/onnx/__main__.py b/optimum/exporters/onnx/__main__.py index 16a18afc552..e61722fefd1 100644 --- a/optimum/exporters/onnx/__main__.py +++ b/optimum/exporters/onnx/__main__.py @@ -348,6 +348,11 @@ def main_export( device=device, library_name=library_name, ) + import torch + model = model.eval() + with torch.no_grad(): + for i in range(len(model.model.decoder.layers)): + model.model.decoder.layers[i].encoder_attn = torch.jit.script(model.model.decoder.layers[i].encoder_attn) custom_architecture = False is_stable_diffusion = "stable-diffusion" in task @@ -574,10 +579,12 @@ def main_export( logger.warning( f"The ONNX export succeeded with the warning: {e}.\n The exported model was saved at: {output.as_posix()}" ) + """ except Exception as e: raise Exception( f"An error occured during validation, but the model was saved nonetheless at {output.as_posix()}. Detailed error: {e}." ) + """ def main(): diff --git a/optimum/exporters/onnx/base.py b/optimum/exporters/onnx/base.py index 1e2ae99955c..ab15215200f 100644 --- a/optimum/exporters/onnx/base.py +++ b/optimum/exporters/onnx/base.py @@ -325,7 +325,7 @@ def fix_dynamic_axes( onnx_inputs[name] = value for name, value in onnx_inputs.items(): - if value.dtype == np.float32 and dtype == "fp16": + if value is not None and (value.dtype == np.float32 and dtype == "fp16"): onnx_inputs[name] = onnx_inputs[name].astype(np.float16) outputs = session.run(None, onnx_inputs) @@ -579,14 +579,16 @@ def __init__( @property def outputs(self) -> Dict[str, Dict[int, str]]: - if not self.use_past_in_inputs: - common_outputs = super().outputs + #if not self.use_past_in_inputs: + # common_outputs = super().outputs + self.use_past_in_inputs = True + # In the other cases, the sequence_length axis is not dynamic, always of length 1 - elif self.task == "feature-extraction": + if self.task == "feature-extraction": common_outputs = OrderedDict({"last_hidden_state": {0: "batch_size"}}) else: - common_outputs = OrderedDict({"logits": {0: "batch_size"}}) - if self.use_past: + common_outputs = OrderedDict({"logits": {}}) + # When exporting decoder models with use_cache=True, both the decoder without past and with past have the KV cache as an output. 
self.add_past_key_values(common_outputs, direction="outputs") return common_outputs @@ -602,7 +604,9 @@ def generate_dummy_inputs(self, framework: str = "pt", **kwargs): dummy_inputs = {} input_names = [key for key in self.inputs.keys() if not key.startswith("past_key_values")] - if self.use_past_in_inputs and self.use_cache_branch is not False: + + print("self._behavior", self._behavior) + if self._behavior is not ConfigBehavior.ENCODER: input_names.append("past_key_values") for input_name in input_names: @@ -821,7 +825,16 @@ def with_behavior( @property def outputs(self) -> Dict[str, Dict[int, str]]: - common_outputs = super(OnnxConfigWithPast, self).outputs + # In the other cases, the sequence_length axis is not dynamic, always of length 1 + if self.task == "feature-extraction": + common_outputs = OrderedDict({"last_hidden_state": {}}) + else: + common_outputs = OrderedDict({"logits": {}}) + + # When exporting decoder models with use_cache=True, both the decoder without past and with past have the KV cache as an output. + self.add_past_key_values(common_outputs, direction="outputs") + + """ # Renaming the outputs axes properly. for name, axes_names in common_outputs.items(): if self._behavior is ConfigBehavior.ENCODER or "encoder" in name: @@ -840,10 +853,12 @@ def outputs(self) -> Dict[str, Dict[int, str]]: else: new_axes_names[axis_idx] = axis_name common_outputs[name] = new_axes_names + if self.use_past: # When exporting decoder models with use_cache=True, both the decoder without past and with past have the KV cache as an output. self.add_past_key_values(common_outputs, direction="outputs") + """ return common_outputs @@ -859,19 +874,17 @@ def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], dire name = "present" for i in range(self._normalized_config.decoder_num_layers): - inputs_or_outputs[f"{name}.{i}.decoder.key"] = {0: "batch_size", 2: decoder_sequence_name} - inputs_or_outputs[f"{name}.{i}.decoder.value"] = {0: "batch_size", 2: decoder_sequence_name} - + inputs_or_outputs[f"{name}.{i}.decoder.key"] = {} + inputs_or_outputs[f"{name}.{i}.decoder.value"] = {} + if ( - self.is_merged is True - or (self._behavior is ConfigBehavior.DECODER and not self.use_past_in_inputs) - or direction == "inputs" + self._behavior is ConfigBehavior.DECODER ): # TODO: we only need to call it encoder_sequence_length_out in the merge case - but at torch.onnx.export() # time we have currently no case to check whether we will merge at a later step or not (self.is_merged is # not yet set at this time) - inputs_or_outputs[f"{name}.{i}.encoder.key"] = {0: "batch_size", 2: "encoder_sequence_length_out"} - inputs_or_outputs[f"{name}.{i}.encoder.value"] = {0: "batch_size", 2: "encoder_sequence_length_out"} + inputs_or_outputs[f"{name}.{i}.encoder.key"] = {} + inputs_or_outputs[f"{name}.{i}.encoder.value"] = {} def flatten_past_key_values(self, flattened_output, name, idx, t): if len(t) not in [2, 4]: diff --git a/optimum/exporters/onnx/config.py b/optimum/exporters/onnx/config.py index 9259ad853da..e1c216dc1cd 100644 --- a/optimum/exporters/onnx/config.py +++ b/optimum/exporters/onnx/config.py @@ -272,6 +272,7 @@ class AudioToTextOnnxConfig(OnnxSeq2SeqConfigWithPast): DummyAudioInputGenerator, DummySeq2SeqDecoderTextInputGenerator, DummySeq2SeqPastKeyValuesGenerator, + DummyTextInputGenerator, ) @property @@ -279,19 +280,18 @@ def inputs(self) -> Dict[str, Dict[int, str]]: common_inputs = {} if self._behavior is not ConfigBehavior.DECODER: - common_inputs["input_features"] = {0: 
"batch_size", 1: "feature_size", 2: "encoder_sequence_length"} + common_inputs["input_features"] = {} if self._behavior is not ConfigBehavior.ENCODER: - if self.use_past_in_inputs: - common_inputs["decoder_input_ids"] = {0: "batch_size"} - else: - common_inputs["decoder_input_ids"] = {0: "batch_size", 1: "decoder_sequence_length"} + common_inputs["decoder_input_ids"] = {} - if self.use_past_in_inputs: - self.add_past_key_values(common_inputs, direction="inputs") + self.add_past_key_values(common_inputs, direction="inputs") + + common_inputs["decoder_attention_mask"] = {} + common_inputs["position_ids"] = {} if self._behavior is ConfigBehavior.DECODER: - common_inputs["encoder_outputs"] = {0: "batch_size", 1: "encoder_sequence_length"} + common_inputs["encoder_outputs"] = {} return common_inputs diff --git a/optimum/exporters/onnx/convert.py b/optimum/exporters/onnx/convert.py index f637da07804..4704e9f9a2b 100644 --- a/optimum/exporters/onnx/convert.py +++ b/optimum/exporters/onnx/convert.py @@ -528,13 +528,6 @@ def export_pytorch( model.config.return_dict = True model.eval() - # Check if we need to override certain configuration item - if config.values_override is not None: - logger.info(f"Overriding {len(config.values_override)} configuration item(s)") - for override_config_key, override_config_value in config.values_override.items(): - logger.info(f"\t- {override_config_key} -> {override_config_value}") - setattr(model.config, override_config_key, override_config_value) - if input_shapes is None: input_shapes = {} # will use the defaults from DEFAULT_DUMMY_SHAPES @@ -555,6 +548,7 @@ def remap(value): dummy_inputs = config.rename_ambiguous_inputs(dummy_inputs) + print("---------------- EXPORT") # PyTorch deprecated the `enable_onnx_checker` and `use_external_data_format` arguments in v1.11, # so we check the torch version for backwards compatibility if is_torch_less_than_1_11: @@ -567,11 +561,28 @@ def remap(value): input_names = list(inputs.keys()) output_names = list(config.outputs.keys()) + print("input_names", input_names) + print("output_names", output_names) + print("dummy_inputs keys", dummy_inputs.keys()) + + for name, inp in dummy_inputs.items(): + if isinstance(inp, torch.Tensor): + print(name, inp.shape) + else: + print(name, type(inp)) + + if config._behavior == "decoder": + dummy_inputs = (dummy_inputs["decoder_input_ids"], dummy_inputs["decoder_attention_mask"], dummy_inputs["encoder_outputs"][0], dummy_inputs["past_key_values"], dummy_inputs["position_ids"]) + + model = torch.jit.trace(model, dummy_inputs) + else: + dummy_inputs = (dummy_inputs,) + # Export can work with named args but the dict containing named args has to be the last element of the args # tuple. onnx_export( model, - (dummy_inputs,), + dummy_inputs, f=output.as_posix(), input_names=input_names, output_names=output_names, @@ -879,6 +890,7 @@ def export( "You either provided a PyTorch model with only TensorFlow installed, or a TensorFlow model with only PyTorch installed." 
) + print("------ FIXING") if not disable_dynamic_axes_fix: config.fix_dynamic_axes(output, device=device, input_shapes=input_shapes, dtype=dtype) return export_output diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index 401d995fdc7..b7a68a6bf9d 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -1130,19 +1130,8 @@ class WhisperOnnxConfig(AudioToTextOnnxConfig): @property def inputs(self) -> Dict[str, Dict[int, str]]: common_inputs = super().inputs - if self._behavior is ConfigBehavior.DECODER and self.use_past_in_inputs is False: - common_inputs["encoder_outputs"][1] = f"{common_inputs['encoder_outputs'][1]} / 2" return common_inputs - @property - def outputs(self) -> Dict[str, Dict[int, str]]: - common_outputs = super().outputs - if self._behavior is ConfigBehavior.ENCODER: - # For Whisper, we need to name the second axis as encoder_sequence_length / 2 as the axis name is used for - # dummy input generation - common_outputs["last_hidden_state"][1] = f"{common_outputs['last_hidden_state'][1]} / 2" - return common_outputs - class Speech2TextDummyAudioInputGenerator(DummyAudioInputGenerator): def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): diff --git a/optimum/exporters/onnx/model_patcher.py b/optimum/exporters/onnx/model_patcher.py index e6b50b6dc08..634b7d9c4ad 100644 --- a/optimum/exporters/onnx/model_patcher.py +++ b/optimum/exporters/onnx/model_patcher.py @@ -193,13 +193,15 @@ def patched_forward(*args, **kwargs): outputs = self.orig_forward(*args, **kwargs) + """ # Filter out cross attention past key values output from the decoder using KV cache, as they are constants. filterd_outputs = {} for name, value in outputs.items(): + # filterd_outputs[name] = value onnx_output_name = config.torch_to_onnx_output_map.get(name, name) if ( onnx_output_name in config.outputs - or (allow_past_in_outputs and name.startswith("past_key_values")) + or name.startswith("past_key_values") or any(key.startswith(onnx_output_name) for key in config.outputs.keys()) ): if name != "past_key_values": @@ -209,16 +211,14 @@ def patched_forward(*args, **kwargs): else: filterd_outputs[name] = value else: - if self.real_config._behavior == "monolith" or ( - self.real_config._behavior == "decoder" - and (self.real_config.is_merged or not self.real_config.use_past_in_inputs) - ): - filterd_outputs[name] = value - elif self.real_config._behavior == "decoder" and self.real_config.use_past_in_inputs: - # The filtering happens here. The decoder with use_past_in_inputs=True corresponds to the autoregressive one. - filterd_outputs[name] = tuple([v[:2] for v in value]) + # The filtering happens here. The decoder with use_past_in_inputs=True corresponds to the autoregressive one. + # filterd_outputs[name] = tuple([v[:2] for v in value]) + filterd_outputs[name] = value + """ - return filterd_outputs + #print("filterd_outputs", filterd_outputs.keys()) + + return outputs self.patched_forward = patched_forward diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py index 5882972d758..9beb137246a 100644 --- a/optimum/exporters/tasks.py +++ b/optimum/exporters/tasks.py @@ -1489,6 +1489,7 @@ def infer_library_from_model( Returns: `str`: The library name automatically detected from the model repo. 
""" + return "transformers" # working offline if library_name is not None: return library_name diff --git a/optimum/utils/input_generators.py b/optimum/utils/input_generators.py index 227c12315d9..e465e1165c2 100644 --- a/optimum/utils/input_generators.py +++ b/optimum/utils/input_generators.py @@ -360,7 +360,13 @@ def __init__( def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): min_value = 0 max_value = 2 if input_name != "input_ids" else self.vocab_size - shape = [self.batch_size, self.sequence_length] + + # TODO: fix + if input_name == "decoder_attention_mask": + shape = [self.batch_size, 128] + else: + shape = [self.batch_size, self.sequence_length] + if self.task == "multiple-choice": shape = [self.batch_size, self.num_choices, self.sequence_length] return self.random_int_tensor(shape, max_value, min_value=min_value, framework=framework, dtype=int_dtype) @@ -414,7 +420,7 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int if input_name in ["encoder_outputs", "encoder_hidden_states"]: return ( self.random_float_tensor( - shape=[self.batch_size, self.sequence_length, self.hidden_size], + shape=[self.batch_size, 1500, self.hidden_size], min_value=0, max_value=1, framework=framework, @@ -492,6 +498,8 @@ def __init__( **kwargs, ): self.normalized_config = normalized_config + self.batch_size = 1 + """ if random_batch_size_range: low, high = random_batch_size_range self.batch_size = random.randint(low, high) @@ -505,6 +513,10 @@ def __init__( self.encoder_sequence_length = ( self.sequence_length if encoder_sequence_length is None else encoder_sequence_length ) + """ + self.encoder_sequence_length = 1500 + self.sequence_length = 128 + def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): encoder_shape = ( @@ -519,7 +531,7 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int self.sequence_length, self.normalized_config.hidden_size // self.normalized_config.decoder_num_attention_heads, ) - return [ + return tuple( ( self.random_float_tensor(decoder_shape, framework=framework, dtype=float_dtype), self.random_float_tensor(decoder_shape, framework=framework, dtype=float_dtype), @@ -527,7 +539,7 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int self.random_float_tensor(encoder_shape, framework=framework, dtype=float_dtype), ) for _ in range(self.normalized_config.decoder_num_layers) - ] + ) # TODO: should it just be merged to DummyTextInputGenerator? 
From 21f3b07f87acd352e00c108f727408d5574bda4a Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Thu, 5 Oct 2023 07:08:06 +0000 Subject: [PATCH 2/5] guard on whisper --- optimum/exporters/onnx/__main__.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/optimum/exporters/onnx/__main__.py b/optimum/exporters/onnx/__main__.py index e61722fefd1..4bc96c3dd63 100644 --- a/optimum/exporters/onnx/__main__.py +++ b/optimum/exporters/onnx/__main__.py @@ -348,11 +348,13 @@ def main_export( device=device, library_name=library_name, ) - import torch - model = model.eval() - with torch.no_grad(): - for i in range(len(model.model.decoder.layers)): - model.model.decoder.layers[i].encoder_attn = torch.jit.script(model.model.decoder.layers[i].encoder_attn) + + if model.config.model_type == "whisper": + import torch + model = model.eval() + with torch.no_grad(): + for i in range(len(model.model.decoder.layers)): + model.model.decoder.layers[i].encoder_attn = torch.jit.script(model.model.decoder.layers[i].encoder_attn) custom_architecture = False is_stable_diffusion = "stable-diffusion" in task From 1a972b3cb79271f991e85aef49b8fdbccee878a8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?F=C3=A9lix=20Marty?= <9808326+fxmarty@users.noreply.github.com> Date: Thu, 5 Oct 2023 10:14:04 +0200 Subject: [PATCH 3/5] working whisper & opt stati --- optimum/exporters/onnx/base.py | 58 ++++++++++++------------- optimum/exporters/onnx/config.py | 14 +++--- optimum/exporters/onnx/convert.py | 37 ++++++++-------- optimum/exporters/onnx/model_configs.py | 28 +++++++++++- optimum/exporters/onnx/model_patcher.py | 4 +- optimum/utils/input_generators.py | 1 + 6 files changed, 82 insertions(+), 60 deletions(-) diff --git a/optimum/exporters/onnx/base.py b/optimum/exporters/onnx/base.py index ab15215200f..e1e0b82b5ac 100644 --- a/optimum/exporters/onnx/base.py +++ b/optimum/exporters/onnx/base.py @@ -581,11 +581,10 @@ def __init__( def outputs(self) -> Dict[str, Dict[int, str]]: #if not self.use_past_in_inputs: # common_outputs = super().outputs - self.use_past_in_inputs = True # In the other cases, the sequence_length axis is not dynamic, always of length 1 if self.task == "feature-extraction": - common_outputs = OrderedDict({"last_hidden_state": {0: "batch_size"}}) + common_outputs = OrderedDict({"last_hidden_state": {}}) else: common_outputs = OrderedDict({"logits": {}}) @@ -605,8 +604,7 @@ def generate_dummy_inputs(self, framework: str = "pt", **kwargs): dummy_inputs = {} input_names = [key for key in self.inputs.keys() if not key.startswith("past_key_values")] - print("self._behavior", self._behavior) - if self._behavior is not ConfigBehavior.ENCODER: + if self.use_past_in_inputs: input_names.append("past_key_values") for input_name in input_names: @@ -627,30 +625,30 @@ def generate_dummy_inputs(self, framework: str = "pt", **kwargs): ) # refer to https://github.com/huggingface/optimum/pull/764 - if ( - self.use_past_in_inputs - and self.PAD_ATTENTION_MASK_TO_PAST - and self.use_cache_branch is not False - and "attention_mask" in dummy_inputs - ): - # Obtain the past sequence length from the value instead of the key (Bloom). 
- past_length = dummy_inputs["past_key_values"][0][1].shape[-2] - - dummy_inputs["attention_mask"] = DummyInputGenerator.pad_input_on_dim( - dummy_inputs["attention_mask"], - desired_length=past_length + 1, - dim=1, - dtype=dummy_inputs["attention_mask"].dtype, - ) - - if self.use_past_in_inputs and self.use_cache_branch is not False and "decoder_attention_mask" in dummy_inputs: - past_length = dummy_inputs["past_key_values"][0][0].shape[2] - dummy_inputs["decoder_attention_mask"] = DummyInputGenerator.pad_input_on_dim( - dummy_inputs["decoder_attention_mask"], - desired_length=past_length + 1, - dim=1, - dtype=dummy_inputs["decoder_attention_mask"].dtype, - ) + # if ( + # self.use_past_in_inputs + # and self.PAD_ATTENTION_MASK_TO_PAST + # and self.use_cache_branch is not False + # and "attention_mask" in dummy_inputs + # ): + # # Obtain the past sequence length from the value instead of the key (Bloom). + # past_length = dummy_inputs["past_key_values"][0][1].shape[-2] + + # dummy_inputs["attention_mask"] = DummyInputGenerator.pad_input_on_dim( + # dummy_inputs["attention_mask"], + # desired_length=past_length + 1, + # dim=1, + # dtype=dummy_inputs["attention_mask"].dtype, + # ) + + # if self.use_past_in_inputs and self.use_cache_branch is not False and "decoder_attention_mask" in dummy_inputs: + # past_length = dummy_inputs["past_key_values"][0][0].shape[2] + # dummy_inputs["decoder_attention_mask"] = DummyInputGenerator.pad_input_on_dim( + # dummy_inputs["decoder_attention_mask"], + # desired_length=past_length + 1, + # dim=1, + # dtype=dummy_inputs["decoder_attention_mask"].dtype, + # ) return dummy_inputs @@ -706,8 +704,8 @@ def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], dire name = "present" for i in range(self._normalized_config.num_layers): - inputs_or_outputs[f"{name}.{i}.key"] = {0: "batch_size", 2: decoder_sequence_name} - inputs_or_outputs[f"{name}.{i}.value"] = {0: "batch_size", 2: decoder_sequence_name} + inputs_or_outputs[f"{name}.{i}.key"] = {} # static shapes + inputs_or_outputs[f"{name}.{i}.value"] = {} def flatten_past_key_values(self, flattened_output, name, idx, t): flattened_output[f"{name}.{idx}.key"] = t[0] diff --git a/optimum/exporters/onnx/config.py b/optimum/exporters/onnx/config.py index e1c216dc1cd..17daad5b6bf 100644 --- a/optimum/exporters/onnx/config.py +++ b/optimum/exporters/onnx/config.py @@ -92,13 +92,13 @@ def __init__( @property def inputs(self) -> Dict[str, Dict[int, str]]: if self.use_past_in_inputs: - common_inputs = {"input_ids": {0: "batch_size"}} + common_inputs = {"input_ids": {}} self.add_past_key_values(common_inputs, direction="inputs") - common_inputs["attention_mask"] = {0: "batch_size", 1: "past_sequence_length + 1"} + common_inputs["attention_mask"] = {} else: common_inputs = { - "input_ids": {0: "batch_size", 1: "sequence_length"}, - "attention_mask": {0: "batch_size", 1: "sequence_length"}, + "input_ids": {}, + "attention_mask": {}, } return common_inputs @@ -109,7 +109,7 @@ def outputs(self) -> Dict[str, Dict[int, str]]: else: # in the merged case, we need to allow the `sequence_length` to be variable, as it is not 1 # during the first pass without past key values - common_outputs = OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}) + common_outputs = OrderedDict({"logits": {}}) self.add_past_key_values(common_outputs, direction="outputs") return common_outputs @@ -165,9 +165,9 @@ def inputs(self) -> Dict[str, Dict[int, str]]: # 
https://github.com/huggingface/transformers/blob/v4.33.1/src/transformers/models/gpt2/modeling_gpt2.py#L802 if not self.no_position_ids and self.task == "text-generation": if self.use_past_in_inputs: - common_inputs["position_ids"] = {0: "batch_size"} + common_inputs["position_ids"] = {} else: - common_inputs["position_ids"] = {0: "batch_size", 1: "sequence_length"} + common_inputs["position_ids"] = {} return common_inputs diff --git a/optimum/exporters/onnx/convert.py b/optimum/exporters/onnx/convert.py index 4704e9f9a2b..7658afdc419 100644 --- a/optimum/exporters/onnx/convert.py +++ b/optimum/exporters/onnx/convert.py @@ -151,24 +151,23 @@ def validate_models_outputs( else output_dir.joinpath(model_name + ".onnx") ) onnx_paths.append(onnx_model_path) - try: - # Model validation is done in subprocesses, as ONNX Runtime has the bad habit of - # not releasing memory once an InferenceSession is initialized. - # Reference: https://github.com/huggingface/optimum/pull/1115 - validate_model_outputs( - config=sub_onnx_config, - reference_model=submodel, - onnx_model=onnx_model_path, - onnx_named_outputs=onnx_named_outputs[i], - atol=atol, - input_shapes=input_shapes, - device=device, - dtype=dtype, - use_subprocess=use_subprocess, - model_kwargs=model_kwargs, - ) - except Exception as e: - exceptions.append(e) + # Model validation is done in subprocesses, as ONNX Runtime has the bad habit of + # not releasing memory once an InferenceSession is initialized. + # Reference: https://github.com/huggingface/optimum/pull/1115 + validate_model_outputs( + config=sub_onnx_config, + reference_model=submodel, + onnx_model=onnx_model_path, + onnx_named_outputs=onnx_named_outputs[i], + atol=atol, + input_shapes=input_shapes, + device=device, + dtype=dtype, + use_subprocess=use_subprocess, + model_kwargs=model_kwargs, + ) + #except Exception as e: + # exceptions.append(e) if len(exceptions) != 0: for i, exception in enumerate(exceptions[:-1]): @@ -571,7 +570,7 @@ def remap(value): else: print(name, type(inp)) - if config._behavior == "decoder": + if model.config.model_type == "whisper" and config._behavior == "decoder": dummy_inputs = (dummy_inputs["decoder_input_ids"], dummy_inputs["decoder_attention_mask"], dummy_inputs["encoder_outputs"][0], dummy_inputs["past_key_values"], dummy_inputs["position_ids"]) model = torch.jit.trace(model, dummy_inputs) diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index b7a68a6bf9d..915ee31403f 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -210,7 +210,7 @@ class GPTNeoXOnnxConfig(TextDecoderWithPositionIdsOnnxConfig): NORMALIZED_CONFIG_CLASS = NormalizedTextConfig -class OPTOnnxConfig(TextDecoderOnnxConfig): +class OPTOnnxConfig(TextDecoderWithPositionIdsOnnxConfig): # OPT does not require position_ids input. 
DEFAULT_ONNX_OPSET = 13 NORMALIZED_CONFIG_CLASS = NormalizedTextConfig @@ -1127,6 +1127,32 @@ class WhisperOnnxConfig(AudioToTextOnnxConfig): NORMALIZED_CONFIG_CLASS = NormalizedSeq2SeqConfig ATOL_FOR_VALIDATION = 1e-3 + def __init__( + self, + config: "PretrainedConfig", + task: str = "feature-extraction", + int_dtype: str = "int64", + float_dtype: str = "fp32", + use_past: bool = False, + use_past_in_inputs: bool = False, + behavior: ConfigBehavior = ConfigBehavior.MONOLITH, + preprocessors: Optional[List[Any]] = None, + ): + super().__init__( + config=config, + task=task, + int_dtype=int_dtype, + float_dtype=float_dtype, + use_past=use_past, + use_past_in_inputs=use_past_in_inputs, + behavior=behavior, + preprocessors=preprocessors, + ) + if self._behavior is ConfigBehavior.ENCODER: + self.use_past_in_inputs = False + else: + self.use_past_in_inputs = True + @property def inputs(self) -> Dict[str, Dict[int, str]]: common_inputs = super().inputs diff --git a/optimum/exporters/onnx/model_patcher.py b/optimum/exporters/onnx/model_patcher.py index 634b7d9c4ad..b8204fe89c0 100644 --- a/optimum/exporters/onnx/model_patcher.py +++ b/optimum/exporters/onnx/model_patcher.py @@ -102,7 +102,7 @@ def __init__( else: self.real_config = config - allow_past_in_outputs = hasattr(self.real_config, "use_past") and self.real_config.use_past + allow_past_in_outputs = (hasattr(self.real_config, "use_past") and self.real_config.use_past) @functools.wraps(self.orig_forward) def patched_forward(*args, **kwargs): @@ -216,8 +216,6 @@ def patched_forward(*args, **kwargs): filterd_outputs[name] = value """ - #print("filterd_outputs", filterd_outputs.keys()) - return outputs self.patched_forward = patched_forward diff --git a/optimum/utils/input_generators.py b/optimum/utils/input_generators.py index e465e1165c2..5aca66de4a8 100644 --- a/optimum/utils/input_generators.py +++ b/optimum/utils/input_generators.py @@ -369,6 +369,7 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int if self.task == "multiple-choice": shape = [self.batch_size, self.num_choices, self.sequence_length] + return self.random_int_tensor(shape, max_value, min_value=min_value, framework=framework, dtype=int_dtype) From 899db7789cd64e4d2ebcce4dc532698022f0a4c0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?F=C3=A9lix=20Marty?= <9808326+fxmarty@users.noreply.github.com> Date: Thu, 5 Oct 2023 10:30:34 +0200 Subject: [PATCH 4/5] nit --- optimum/exporters/onnx/config.py | 52 ++++++++++++++++---------------- 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/optimum/exporters/onnx/config.py b/optimum/exporters/onnx/config.py index 17daad5b6bf..563a6452290 100644 --- a/optimum/exporters/onnx/config.py +++ b/optimum/exporters/onnx/config.py @@ -125,32 +125,32 @@ def post_process_exported_models( path, models_and_onnx_configs, onnx_files_subpaths ) - # Attempt to merge only if the decoder-only was exported separately without/with past - if self.use_past is True and len(models_and_onnx_configs) == 2: - decoder_path = Path(path, onnx_files_subpaths[0]) - decoder_with_past_path = Path(path, onnx_files_subpaths[1]) - decoder_merged_path = Path(path, ONNX_DECODER_MERGED_NAME + ".onnx") - try: - merge_decoders( - decoder=decoder_path, - decoder_with_past=decoder_with_past_path, - save_path=decoder_merged_path, - ) - except Exception as e: - raise Exception(f"Unable to merge decoders. 
Detailed error: {e}") - - # In order to do the validation of the two branches on the same file - onnx_files_subpaths = [decoder_merged_path.name, decoder_merged_path.name] - - # We validate the two branches of the decoder model then - models_and_onnx_configs[ONNX_DECODER_NAME][1].is_merged = True - models_and_onnx_configs[ONNX_DECODER_NAME][1].use_cache_branch = False - - # Past key values won't be generated by default, but added in the input - models_and_onnx_configs[ONNX_DECODER_NAME][1].use_past_in_inputs = True - - models_and_onnx_configs[ONNX_DECODER_WITH_PAST_NAME][1].use_cache_branch = True - models_and_onnx_configs[ONNX_DECODER_WITH_PAST_NAME][1].is_merged = True + # # Attempt to merge only if the decoder-only was exported separately without/with past + # if self.use_past is True and len(models_and_onnx_configs) == 2: + # decoder_path = Path(path, onnx_files_subpaths[0]) + # decoder_with_past_path = Path(path, onnx_files_subpaths[1]) + # decoder_merged_path = Path(path, ONNX_DECODER_MERGED_NAME + ".onnx") + # try: + # merge_decoders( + # decoder=decoder_path, + # decoder_with_past=decoder_with_past_path, + # save_path=decoder_merged_path, + # ) + # except Exception as e: + # raise Exception(f"Unable to merge decoders. Detailed error: {e}") + + # # In order to do the validation of the two branches on the same file + # onnx_files_subpaths = [decoder_merged_path.name, decoder_merged_path.name] + + # # We validate the two branches of the decoder model then + # models_and_onnx_configs[ONNX_DECODER_NAME][1].is_merged = True + # models_and_onnx_configs[ONNX_DECODER_NAME][1].use_cache_branch = False + + # # Past key values won't be generated by default, but added in the input + # models_and_onnx_configs[ONNX_DECODER_NAME][1].use_past_in_inputs = True + + # models_and_onnx_configs[ONNX_DECODER_WITH_PAST_NAME][1].use_cache_branch = True + # models_and_onnx_configs[ONNX_DECODER_WITH_PAST_NAME][1].is_merged = True return models_and_onnx_configs, onnx_files_subpaths From 841a7ce650efeb9e137c5c9ec38b42f22d37a798 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?F=C3=A9lix=20Marty?= <9808326+fxmarty@users.noreply.github.com> Date: Thu, 5 Oct 2023 16:06:16 +0200 Subject: [PATCH 5/5] update --- optimum/utils/input_generators.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/optimum/utils/input_generators.py b/optimum/utils/input_generators.py index 5aca66de4a8..103acd79abd 100644 --- a/optimum/utils/input_generators.py +++ b/optimum/utils/input_generators.py @@ -363,7 +363,7 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int # TODO: fix if input_name == "decoder_attention_mask": - shape = [self.batch_size, 128] + shape = [self.batch_size, 448] # TODO: fix to max_length for whisper else: shape = [self.batch_size, self.sequence_length] @@ -516,7 +516,7 @@ def __init__( ) """ self.encoder_sequence_length = 1500 - self.sequence_length = 128 + self.sequence_length = 448 # TODO: fix to max_length for whisper def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
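
The hard-coded 1500 and 448 are Whisper architecture constants, which is what the remaining "TODO: fix to max_length for whisper" comments point at. A sketch of reading them from the model config instead of hard-coding, assuming a transformers WhisperConfig (the attribute names are from transformers; the values shown are those of the released Whisper checkpoints):

    from transformers import WhisperConfig

    config = WhisperConfig.from_pretrained("openai/whisper-tiny")  # example checkpoint

    # 30 s of audio gives 3000 mel frames, halved by the encoder's stride-2
    # convolution: the fixed encoder sequence length hard-coded as 1500 above.
    encoder_sequence_length = config.max_source_positions  # 1500

    # Upper bound on the number of generated tokens: the 448 used for the
    # decoder_attention_mask and past key/value dummy shapes.
    decoder_max_length = config.max_target_positions  # 448

Threading these two values through the dummy input generators would remove the Whisper-specific constants while keeping the exported graphs fully static, as these patches intend.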