diff --git a/optimum/exporters/onnx/__main__.py b/optimum/exporters/onnx/__main__.py
index 16a18afc552..4bc96c3dd63 100644
--- a/optimum/exporters/onnx/__main__.py
+++ b/optimum/exporters/onnx/__main__.py
@@ -349,6 +349,13 @@ def main_export(
         library_name=library_name,
     )
 
+    if model.config.model_type == "whisper":
+        import torch
+        model = model.eval()
+        with torch.no_grad():
+            for i in range(len(model.model.decoder.layers)):
+                model.model.decoder.layers[i].encoder_attn = torch.jit.script(model.model.decoder.layers[i].encoder_attn)
+
     custom_architecture = False
     is_stable_diffusion = "stable-diffusion" in task
     model_type = "stable-diffusion" if is_stable_diffusion else model.config.model_type.replace("_", "-")
@@ -574,10 +581,12 @@ def main_export(
             logger.warning(
                 f"The ONNX export succeeded with the warning: {e}.\n The exported model was saved at: {output.as_posix()}"
             )
+        """
         except Exception as e:
             raise Exception(
                 f"An error occured during validation, but the model was saved nonetheless at {output.as_posix()}. Detailed error: {e}."
             )
+        """
 
 
 def main():
diff --git a/optimum/exporters/onnx/base.py b/optimum/exporters/onnx/base.py
index 1e2ae99955c..e1e0b82b5ac 100644
--- a/optimum/exporters/onnx/base.py
+++ b/optimum/exporters/onnx/base.py
@@ -325,7 +325,7 @@ def fix_dynamic_axes(
                 onnx_inputs[name] = value
 
         for name, value in onnx_inputs.items():
-            if value.dtype == np.float32 and dtype == "fp16":
+            if value is not None and (value.dtype == np.float32 and dtype == "fp16"):
                 onnx_inputs[name] = onnx_inputs[name].astype(np.float16)
 
         outputs = session.run(None, onnx_inputs)
@@ -579,14 +579,15 @@ def __init__(
 
     @property
     def outputs(self) -> Dict[str, Dict[int, str]]:
-        if not self.use_past_in_inputs:
-            common_outputs = super().outputs
+        #if not self.use_past_in_inputs:
+        #    common_outputs = super().outputs
+
         # In the other cases, the sequence_length axis is not dynamic, always of length 1
-        elif self.task == "feature-extraction":
-            common_outputs = OrderedDict({"last_hidden_state": {0: "batch_size"}})
+        if self.task == "feature-extraction":
+            common_outputs = OrderedDict({"last_hidden_state": {}})
         else:
-            common_outputs = OrderedDict({"logits": {0: "batch_size"}})
-        if self.use_past:
+            common_outputs = OrderedDict({"logits": {}})
+        # When exporting decoder models with use_cache=True, both the decoder without past and with past have the KV cache as an output.
             self.add_past_key_values(common_outputs, direction="outputs")
         return common_outputs
 
@@ -602,7 +603,8 @@ def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
 
         dummy_inputs = {}
         input_names = [key for key in self.inputs.keys() if not key.startswith("past_key_values")]
-        if self.use_past_in_inputs and self.use_cache_branch is not False:
+
+        if self.use_past_in_inputs:
             input_names.append("past_key_values")
 
         for input_name in input_names:
@@ -623,30 +625,30 @@ def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
                 )
 
         # refer to https://github.com/huggingface/optimum/pull/764
-        if (
-            self.use_past_in_inputs
-            and self.PAD_ATTENTION_MASK_TO_PAST
-            and self.use_cache_branch is not False
-            and "attention_mask" in dummy_inputs
-        ):
-            # Obtain the past sequence length from the value instead of the key (Bloom).
-            past_length = dummy_inputs["past_key_values"][0][1].shape[-2]
-
-            dummy_inputs["attention_mask"] = DummyInputGenerator.pad_input_on_dim(
-                dummy_inputs["attention_mask"],
-                desired_length=past_length + 1,
-                dim=1,
-                dtype=dummy_inputs["attention_mask"].dtype,
-            )
-
-        if self.use_past_in_inputs and self.use_cache_branch is not False and "decoder_attention_mask" in dummy_inputs:
-            past_length = dummy_inputs["past_key_values"][0][0].shape[2]
-            dummy_inputs["decoder_attention_mask"] = DummyInputGenerator.pad_input_on_dim(
-                dummy_inputs["decoder_attention_mask"],
-                desired_length=past_length + 1,
-                dim=1,
-                dtype=dummy_inputs["decoder_attention_mask"].dtype,
-            )
+        # if (
+        #     self.use_past_in_inputs
+        #     and self.PAD_ATTENTION_MASK_TO_PAST
+        #     and self.use_cache_branch is not False
+        #     and "attention_mask" in dummy_inputs
+        # ):
+        #     # Obtain the past sequence length from the value instead of the key (Bloom).
+        #     past_length = dummy_inputs["past_key_values"][0][1].shape[-2]
+
+        #     dummy_inputs["attention_mask"] = DummyInputGenerator.pad_input_on_dim(
+        #         dummy_inputs["attention_mask"],
+        #         desired_length=past_length + 1,
+        #         dim=1,
+        #         dtype=dummy_inputs["attention_mask"].dtype,
+        #     )
+
+        # if self.use_past_in_inputs and self.use_cache_branch is not False and "decoder_attention_mask" in dummy_inputs:
+        #     past_length = dummy_inputs["past_key_values"][0][0].shape[2]
+        #     dummy_inputs["decoder_attention_mask"] = DummyInputGenerator.pad_input_on_dim(
+        #         dummy_inputs["decoder_attention_mask"],
+        #         desired_length=past_length + 1,
+        #         dim=1,
+        #         dtype=dummy_inputs["decoder_attention_mask"].dtype,
+        #     )
 
         return dummy_inputs
 
@@ -702,8 +704,8 @@ def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], dire
             name = "present"
 
         for i in range(self._normalized_config.num_layers):
-            inputs_or_outputs[f"{name}.{i}.key"] = {0: "batch_size", 2: decoder_sequence_name}
-            inputs_or_outputs[f"{name}.{i}.value"] = {0: "batch_size", 2: decoder_sequence_name}
+            inputs_or_outputs[f"{name}.{i}.key"] = {}  # static shapes
+            inputs_or_outputs[f"{name}.{i}.value"] = {}
 
     def flatten_past_key_values(self, flattened_output, name, idx, t):
         flattened_output[f"{name}.{idx}.key"] = t[0]
@@ -821,7 +823,16 @@ def with_behavior(
 
     @property
     def outputs(self) -> Dict[str, Dict[int, str]]:
-        common_outputs = super(OnnxConfigWithPast, self).outputs
+        # In the other cases, the sequence_length axis is not dynamic, always of length 1
+        if self.task == "feature-extraction":
+            common_outputs = OrderedDict({"last_hidden_state": {}})
+        else:
+            common_outputs = OrderedDict({"logits": {}})
+
+        # When exporting decoder models with use_cache=True, both the decoder without past and with past have the KV cache as an output.
+        self.add_past_key_values(common_outputs, direction="outputs")
+
+        """
         # Renaming the outputs axes properly.
         for name, axes_names in common_outputs.items():
             if self._behavior is ConfigBehavior.ENCODER or "encoder" in name:
@@ -840,10 +851,12 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
                 else:
                     new_axes_names[axis_idx] = axis_name
             common_outputs[name] = new_axes_names
+
         if self.use_past:
             # When exporting decoder models with use_cache=True, both the decoder without past and with past have the KV cache as an output.
             self.add_past_key_values(common_outputs, direction="outputs")
+        """
 
         return common_outputs
 
@@ -859,19 +872,17 @@ def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], dire
             name = "present"
 
         for i in range(self._normalized_config.decoder_num_layers):
-            inputs_or_outputs[f"{name}.{i}.decoder.key"] = {0: "batch_size", 2: decoder_sequence_name}
-            inputs_or_outputs[f"{name}.{i}.decoder.value"] = {0: "batch_size", 2: decoder_sequence_name}
-
+            inputs_or_outputs[f"{name}.{i}.decoder.key"] = {}
+            inputs_or_outputs[f"{name}.{i}.decoder.value"] = {}
+
         if (
-            self.is_merged is True
-            or (self._behavior is ConfigBehavior.DECODER and not self.use_past_in_inputs)
-            or direction == "inputs"
+            self._behavior is ConfigBehavior.DECODER
         ):
             # TODO: we only need to call it encoder_sequence_length_out in the merge case - but at torch.onnx.export()
             # time we have currently no case to check whether we will merge at a later step or not (self.is_merged is
             # not yet set at this time)
-            inputs_or_outputs[f"{name}.{i}.encoder.key"] = {0: "batch_size", 2: "encoder_sequence_length_out"}
-            inputs_or_outputs[f"{name}.{i}.encoder.value"] = {0: "batch_size", 2: "encoder_sequence_length_out"}
+            inputs_or_outputs[f"{name}.{i}.encoder.key"] = {}
+            inputs_or_outputs[f"{name}.{i}.encoder.value"] = {}
 
     def flatten_past_key_values(self, flattened_output, name, idx, t):
         if len(t) not in [2, 4]:
diff --git a/optimum/exporters/onnx/config.py b/optimum/exporters/onnx/config.py
index 9259ad853da..563a6452290 100644
--- a/optimum/exporters/onnx/config.py
+++ b/optimum/exporters/onnx/config.py
@@ -92,13 +92,13 @@ def __init__(
 
     @property
     def inputs(self) -> Dict[str, Dict[int, str]]:
         if self.use_past_in_inputs:
-            common_inputs = {"input_ids": {0: "batch_size"}}
+            common_inputs = {"input_ids": {}}
             self.add_past_key_values(common_inputs, direction="inputs")
-            common_inputs["attention_mask"] = {0: "batch_size", 1: "past_sequence_length + 1"}
+            common_inputs["attention_mask"] = {}
         else:
             common_inputs = {
-                "input_ids": {0: "batch_size", 1: "sequence_length"},
-                "attention_mask": {0: "batch_size", 1: "sequence_length"},
+                "input_ids": {},
+                "attention_mask": {},
             }
         return common_inputs
@@ -109,7 +109,7 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
         else:
             # in the merged case, we need to allow the `sequence_length` to be variable, as it is not 1
             # during the first pass without past key values
-            common_outputs = OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}})
+            common_outputs = OrderedDict({"logits": {}})
         self.add_past_key_values(common_outputs, direction="outputs")
         return common_outputs
@@ -125,32 +125,32 @@ def post_process_exported_models(
             path, models_and_onnx_configs, onnx_files_subpaths
         )
 
-        # Attempt to merge only if the decoder-only was exported separately without/with past
-        if self.use_past is True and len(models_and_onnx_configs) == 2:
-            decoder_path = Path(path, onnx_files_subpaths[0])
-            decoder_with_past_path = Path(path, onnx_files_subpaths[1])
-            decoder_merged_path = Path(path, ONNX_DECODER_MERGED_NAME + ".onnx")
-            try:
-                merge_decoders(
-                    decoder=decoder_path,
-                    decoder_with_past=decoder_with_past_path,
-                    save_path=decoder_merged_path,
-                )
-            except Exception as e:
-                raise Exception(f"Unable to merge decoders. Detailed error: {e}")
-
-            # In order to do the validation of the two branches on the same file
-            onnx_files_subpaths = [decoder_merged_path.name, decoder_merged_path.name]
-
-            # We validate the two branches of the decoder model then
-            models_and_onnx_configs[ONNX_DECODER_NAME][1].is_merged = True
-            models_and_onnx_configs[ONNX_DECODER_NAME][1].use_cache_branch = False
-
-            # Past key values won't be generated by default, but added in the input
-            models_and_onnx_configs[ONNX_DECODER_NAME][1].use_past_in_inputs = True
-
-            models_and_onnx_configs[ONNX_DECODER_WITH_PAST_NAME][1].use_cache_branch = True
-            models_and_onnx_configs[ONNX_DECODER_WITH_PAST_NAME][1].is_merged = True
+        # # Attempt to merge only if the decoder-only was exported separately without/with past
+        # if self.use_past is True and len(models_and_onnx_configs) == 2:
+        #     decoder_path = Path(path, onnx_files_subpaths[0])
+        #     decoder_with_past_path = Path(path, onnx_files_subpaths[1])
+        #     decoder_merged_path = Path(path, ONNX_DECODER_MERGED_NAME + ".onnx")
+        #     try:
+        #         merge_decoders(
+        #             decoder=decoder_path,
+        #             decoder_with_past=decoder_with_past_path,
+        #             save_path=decoder_merged_path,
+        #         )
+        #     except Exception as e:
+        #         raise Exception(f"Unable to merge decoders. Detailed error: {e}")
+
+        # # In order to do the validation of the two branches on the same file
+        # onnx_files_subpaths = [decoder_merged_path.name, decoder_merged_path.name]
+
+        # # We validate the two branches of the decoder model then
+        # models_and_onnx_configs[ONNX_DECODER_NAME][1].is_merged = True
+        # models_and_onnx_configs[ONNX_DECODER_NAME][1].use_cache_branch = False
+
+        # # Past key values won't be generated by default, but added in the input
+        # models_and_onnx_configs[ONNX_DECODER_NAME][1].use_past_in_inputs = True
+
+        # models_and_onnx_configs[ONNX_DECODER_WITH_PAST_NAME][1].use_cache_branch = True
+        # models_and_onnx_configs[ONNX_DECODER_WITH_PAST_NAME][1].is_merged = True
 
         return models_and_onnx_configs, onnx_files_subpaths
@@ -165,9 +165,9 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
         # https://github.com/huggingface/transformers/blob/v4.33.1/src/transformers/models/gpt2/modeling_gpt2.py#L802
         if not self.no_position_ids and self.task == "text-generation":
             if self.use_past_in_inputs:
-                common_inputs["position_ids"] = {0: "batch_size"}
+                common_inputs["position_ids"] = {}
             else:
-                common_inputs["position_ids"] = {0: "batch_size", 1: "sequence_length"}
+                common_inputs["position_ids"] = {}
 
         return common_inputs
@@ -272,6 +272,7 @@ class AudioToTextOnnxConfig(OnnxSeq2SeqConfigWithPast):
         DummyAudioInputGenerator,
         DummySeq2SeqDecoderTextInputGenerator,
         DummySeq2SeqPastKeyValuesGenerator,
+        DummyTextInputGenerator,
     )
 
     @property
@@ -279,19 +280,18 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
         common_inputs = {}
 
         if self._behavior is not ConfigBehavior.DECODER:
-            common_inputs["input_features"] = {0: "batch_size", 1: "feature_size", 2: "encoder_sequence_length"}
+            common_inputs["input_features"] = {}
 
         if self._behavior is not ConfigBehavior.ENCODER:
-            if self.use_past_in_inputs:
-                common_inputs["decoder_input_ids"] = {0: "batch_size"}
-            else:
-                common_inputs["decoder_input_ids"] = {0: "batch_size", 1: "decoder_sequence_length"}
+            common_inputs["decoder_input_ids"] = {}
 
-        if self.use_past_in_inputs:
-            self.add_past_key_values(common_inputs, direction="inputs")
+            self.add_past_key_values(common_inputs, direction="inputs")
+
+            common_inputs["decoder_attention_mask"] = {}
+            common_inputs["position_ids"] = {}
 
         if self._behavior is ConfigBehavior.DECODER:
-            common_inputs["encoder_outputs"] = {0: "batch_size", 1: "encoder_sequence_length"}
+            common_inputs["encoder_outputs"] = {}
 
         return common_inputs
diff --git a/optimum/exporters/onnx/convert.py b/optimum/exporters/onnx/convert.py
index f637da07804..7658afdc419 100644
--- a/optimum/exporters/onnx/convert.py
+++ b/optimum/exporters/onnx/convert.py
@@ -151,24 +151,23 @@ def validate_models_outputs(
             else output_dir.joinpath(model_name + ".onnx")
         )
         onnx_paths.append(onnx_model_path)
-        try:
-            # Model validation is done in subprocesses, as ONNX Runtime has the bad habit of
-            # not releasing memory once an InferenceSession is initialized.
-            # Reference: https://github.com/huggingface/optimum/pull/1115
-            validate_model_outputs(
-                config=sub_onnx_config,
-                reference_model=submodel,
-                onnx_model=onnx_model_path,
-                onnx_named_outputs=onnx_named_outputs[i],
-                atol=atol,
-                input_shapes=input_shapes,
-                device=device,
-                dtype=dtype,
-                use_subprocess=use_subprocess,
-                model_kwargs=model_kwargs,
-            )
-        except Exception as e:
-            exceptions.append(e)
+        # Model validation is done in subprocesses, as ONNX Runtime has the bad habit of
+        # not releasing memory once an InferenceSession is initialized.
+        # Reference: https://github.com/huggingface/optimum/pull/1115
+        validate_model_outputs(
+            config=sub_onnx_config,
+            reference_model=submodel,
+            onnx_model=onnx_model_path,
+            onnx_named_outputs=onnx_named_outputs[i],
+            atol=atol,
+            input_shapes=input_shapes,
+            device=device,
+            dtype=dtype,
+            use_subprocess=use_subprocess,
+            model_kwargs=model_kwargs,
+        )
+        #except Exception as e:
+        #    exceptions.append(e)
 
     if len(exceptions) != 0:
         for i, exception in enumerate(exceptions[:-1]):
@@ -528,13 +527,6 @@ def export_pytorch(
         model.config.return_dict = True
         model.eval()
 
-        # Check if we need to override certain configuration item
-        if config.values_override is not None:
-            logger.info(f"Overriding {len(config.values_override)} configuration item(s)")
-            for override_config_key, override_config_value in config.values_override.items():
-                logger.info(f"\t- {override_config_key} -> {override_config_value}")
-                setattr(model.config, override_config_key, override_config_value)
-
         if input_shapes is None:
             input_shapes = {}  # will use the defaults from DEFAULT_DUMMY_SHAPES
@@ -555,6 +547,7 @@ def remap(value):
 
         dummy_inputs = config.rename_ambiguous_inputs(dummy_inputs)
 
+        print("---------------- EXPORT")
         # PyTorch deprecated the `enable_onnx_checker` and `use_external_data_format` arguments in v1.11,
         # so we check the torch version for backwards compatibility
         if is_torch_less_than_1_11:
@@ -567,11 +560,28 @@ def remap(value):
                 input_names = list(inputs.keys())
                 output_names = list(config.outputs.keys())
 
+                print("input_names", input_names)
+                print("output_names", output_names)
+                print("dummy_inputs keys", dummy_inputs.keys())
+
+                for name, inp in dummy_inputs.items():
+                    if isinstance(inp, torch.Tensor):
+                        print(name, inp.shape)
+                    else:
+                        print(name, type(inp))
+
+                if model.config.model_type == "whisper" and config._behavior == "decoder":
+                    dummy_inputs = (dummy_inputs["decoder_input_ids"], dummy_inputs["decoder_attention_mask"], dummy_inputs["encoder_outputs"][0], dummy_inputs["past_key_values"], dummy_inputs["position_ids"])
+
+                    model = torch.jit.trace(model, dummy_inputs)
+                else:
+                    dummy_inputs = (dummy_inputs,)
+
                 # Export can work with named args but the dict containing named args has to be the last element of the args
                 # tuple.
                 onnx_export(
                     model,
-                    (dummy_inputs,),
+                    dummy_inputs,
                     f=output.as_posix(),
                     input_names=input_names,
                     output_names=output_names,
@@ -879,6 +889,7 @@ def export(
             "You either provided a PyTorch model with only TensorFlow installed, or a TensorFlow model with only PyTorch installed."
         )
 
+    print("------ FIXING")
     if not disable_dynamic_axes_fix:
         config.fix_dynamic_axes(output, device=device, input_shapes=input_shapes, dtype=dtype)
     return export_output
diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py
index 401d995fdc7..915ee31403f 100644
--- a/optimum/exporters/onnx/model_configs.py
+++ b/optimum/exporters/onnx/model_configs.py
@@ -210,7 +210,7 @@ class GPTNeoXOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
     NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
 
 
-class OPTOnnxConfig(TextDecoderOnnxConfig):
+class OPTOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
     # OPT does not require position_ids input.
     DEFAULT_ONNX_OPSET = 13
     NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
@@ -1127,22 +1127,37 @@ class WhisperOnnxConfig(AudioToTextOnnxConfig):
     NORMALIZED_CONFIG_CLASS = NormalizedSeq2SeqConfig
     ATOL_FOR_VALIDATION = 1e-3
 
+    def __init__(
+        self,
+        config: "PretrainedConfig",
+        task: str = "feature-extraction",
+        int_dtype: str = "int64",
+        float_dtype: str = "fp32",
+        use_past: bool = False,
+        use_past_in_inputs: bool = False,
+        behavior: ConfigBehavior = ConfigBehavior.MONOLITH,
+        preprocessors: Optional[List[Any]] = None,
+    ):
+        super().__init__(
+            config=config,
+            task=task,
+            int_dtype=int_dtype,
+            float_dtype=float_dtype,
+            use_past=use_past,
+            use_past_in_inputs=use_past_in_inputs,
+            behavior=behavior,
+            preprocessors=preprocessors,
+        )
+        if self._behavior is ConfigBehavior.ENCODER:
+            self.use_past_in_inputs = False
+        else:
+            self.use_past_in_inputs = True
+
     @property
     def inputs(self) -> Dict[str, Dict[int, str]]:
         common_inputs = super().inputs
-        if self._behavior is ConfigBehavior.DECODER and self.use_past_in_inputs is False:
-            common_inputs["encoder_outputs"][1] = f"{common_inputs['encoder_outputs'][1]} / 2"
         return common_inputs
 
-    @property
-    def outputs(self) -> Dict[str, Dict[int, str]]:
-        common_outputs = super().outputs
-        if self._behavior is ConfigBehavior.ENCODER:
-            # For Whisper, we need to name the second axis as encoder_sequence_length / 2 as the axis name is used for
-            # dummy input generation
-            common_outputs["last_hidden_state"][1] = f"{common_outputs['last_hidden_state'][1]} / 2"
-        return common_outputs
-
 
 class Speech2TextDummyAudioInputGenerator(DummyAudioInputGenerator):
     def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
diff --git a/optimum/exporters/onnx/model_patcher.py b/optimum/exporters/onnx/model_patcher.py
index e6b50b6dc08..b8204fe89c0 100644
--- a/optimum/exporters/onnx/model_patcher.py
+++ b/optimum/exporters/onnx/model_patcher.py
@@ -102,7 +102,7 @@ def __init__(
         else:
             self.real_config = config
 
-        allow_past_in_outputs = hasattr(self.real_config, "use_past") and self.real_config.use_past
+        allow_past_in_outputs = (hasattr(self.real_config, "use_past") and self.real_config.use_past)
 
         @functools.wraps(self.orig_forward)
         def patched_forward(*args, **kwargs):
@@ -193,13 +193,15 @@ def patched_forward(*args, **kwargs):
 
             outputs = self.orig_forward(*args, **kwargs)
 
+            """
             # Filter out cross attention past key values output from the decoder using KV cache, as they are constants.
             filterd_outputs = {}
             for name, value in outputs.items():
+                # filterd_outputs[name] = value
                 onnx_output_name = config.torch_to_onnx_output_map.get(name, name)
                 if (
                     onnx_output_name in config.outputs
-                    or (allow_past_in_outputs and name.startswith("past_key_values"))
+                    or name.startswith("past_key_values")
                     or any(key.startswith(onnx_output_name) for key in config.outputs.keys())
                 ):
                     if name != "past_key_values":
@@ -209,16 +211,12 @@ def patched_forward(*args, **kwargs):
                         else:
                             filterd_outputs[name] = value
                     else:
-                        if self.real_config._behavior == "monolith" or (
-                            self.real_config._behavior == "decoder"
-                            and (self.real_config.is_merged or not self.real_config.use_past_in_inputs)
-                        ):
-                            filterd_outputs[name] = value
-                        elif self.real_config._behavior == "decoder" and self.real_config.use_past_in_inputs:
-                            # The filtering happens here. The decoder with use_past_in_inputs=True corresponds to the autoregressive one.
-                            filterd_outputs[name] = tuple([v[:2] for v in value])
+                        # The filtering happens here. The decoder with use_past_in_inputs=True corresponds to the autoregressive one.
+                        # filterd_outputs[name] = tuple([v[:2] for v in value])
+                        filterd_outputs[name] = value
+            """
 
-            return filterd_outputs
+            return outputs
 
         self.patched_forward = patched_forward
diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py
index 5882972d758..9beb137246a 100644
--- a/optimum/exporters/tasks.py
+++ b/optimum/exporters/tasks.py
@@ -1489,6 +1489,7 @@ def infer_library_from_model(
         Returns:
             `str`: The library name automatically detected from the model repo.
         """
+        return "transformers"  # working offline
         if library_name is not None:
             return library_name
diff --git a/optimum/utils/input_generators.py b/optimum/utils/input_generators.py
index 227c12315d9..103acd79abd 100644
--- a/optimum/utils/input_generators.py
+++ b/optimum/utils/input_generators.py
@@ -360,9 +360,16 @@ def __init__(
     def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
         min_value = 0
         max_value = 2 if input_name != "input_ids" else self.vocab_size
-        shape = [self.batch_size, self.sequence_length]
+
+        # TODO: fix
+        if input_name == "decoder_attention_mask":
+            shape = [self.batch_size, 448]  # TODO: fix to max_length for whisper
+        else:
+            shape = [self.batch_size, self.sequence_length]
+
         if self.task == "multiple-choice":
             shape = [self.batch_size, self.num_choices, self.sequence_length]
+
         return self.random_int_tensor(shape, max_value, min_value=min_value, framework=framework, dtype=int_dtype)
@@ -414,7 +421,7 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int
         if input_name in ["encoder_outputs", "encoder_hidden_states"]:
             return (
                 self.random_float_tensor(
-                    shape=[self.batch_size, self.sequence_length, self.hidden_size],
+                    shape=[self.batch_size, 1500, self.hidden_size],
                     min_value=0,
                     max_value=1,
                     framework=framework,
@@ -492,6 +499,8 @@ def __init__(
         **kwargs,
     ):
         self.normalized_config = normalized_config
+        self.batch_size = 1
+        """
         if random_batch_size_range:
             low, high = random_batch_size_range
             self.batch_size = random.randint(low, high)
@@ -505,6 +514,10 @@ def __init__(
         self.encoder_sequence_length = (
             self.sequence_length if encoder_sequence_length is None else encoder_sequence_length
         )
+        """
+        self.encoder_sequence_length = 1500
+        self.sequence_length = 448  # TODO: fix to max_length for whisper
+
 
     def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
         encoder_shape = (
@@ -519,7 +532,7 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int
             self.sequence_length,
             self.normalized_config.hidden_size // self.normalized_config.decoder_num_attention_heads,
         )
-        return [
+        return tuple(
             (
                 self.random_float_tensor(decoder_shape, framework=framework, dtype=float_dtype),
                 self.random_float_tensor(decoder_shape, framework=framework, dtype=float_dtype),
@@ -527,7 +540,7 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
                 self.random_float_tensor(encoder_shape, framework=framework, dtype=float_dtype),
             )
             for _ in range(self.normalized_config.decoder_num_layers)
-        ]
+        )
 
 
 # TODO: should it just be merged to DummyTextInputGenerator?