From 31026499ac4f6bdb3a7f59de4861b56d67683d90 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?F=C3=A9lix=20Marty?= <9808326+fxmarty@users.noreply.github.com> Date: Thu, 5 Oct 2023 17:07:38 +0200 Subject: [PATCH] fix export --- optimum/exporters/onnx/model_configs.py | 43 ++++++++++++++++++++++++- 1 file changed, 42 insertions(+), 1 deletion(-) diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index 6c7828d77ac..d0372f3145e 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -216,9 +216,50 @@ class OPTOnnxConfig(TextDecoderOnnxConfig): NORMALIZED_CONFIG_CLASS = NormalizedTextConfig +class LlamaDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator): + def __init__( + self, + task: str, + normalized_config: NormalizedTextConfig, + batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"], + sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"], + random_batch_size_range: Optional[Tuple[int, int]] = None, + random_sequence_length_range: Optional[Tuple[int, int]] = None, + **kwargs, + ): + super().__init__( + task=task, + normalized_config=normalized_config, + batch_size=batch_size, + sequence_length=sequence_length, + random_batch_size_range=random_batch_size_range, + random_sequence_length_range=random_sequence_length_range, + **kwargs, + ) + self.num_key_value_heads = normalized_config.num_key_value_heads + + def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): + shape = ( + self.batch_size, + self.num_key_value_heads, + self.sequence_length, + self.hidden_size // self.num_attention_heads, + ) + return [ + ( + self.random_float_tensor(shape, framework=framework, dtype=float_dtype), + self.random_float_tensor(shape, framework=framework, dtype=float_dtype), + ) + for _ in range(self.num_layers) + ] + + class LlamaOnnxConfig(TextDecoderWithPositionIdsOnnxConfig): + DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, LlamaDummyPastKeyValuesGenerator) + DUMMY_PKV_GENERATOR_CLASS = LlamaDummyPastKeyValuesGenerator + DEFAULT_ONNX_OPSET = 13 - NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(num_attention_heads="num_key_value_heads") + NORMALIZED_CONFIG_CLASS = NormalizedTextConfig class MPTOnnxConfig(TextDecoderOnnxConfig):