From c3b584aef57b844e96da1c5d4ef614e26cedcc88 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Fri, 6 Oct 2023 14:26:46 +0200 Subject: [PATCH] fix after #1432 merged --- optimum/exporters/onnx/model_configs.py | 44 ++----------------------- optimum/onnxruntime/base.py | 4 +-- optimum/utils/input_generators.py | 4 +-- 3 files changed, 7 insertions(+), 45 deletions(-) diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index 31d56411a5a..b25bfe156af 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -43,6 +43,7 @@ NormalizedVisionConfig, logging, ) + from ...utils.normalized_config import NormalizedConfigManager from .base import ConfigBehavior, OnnxConfig, OnnxConfigWithPast, OnnxSeq2SeqConfigWithPast from .config import ( @@ -217,48 +218,9 @@ class OPTOnnxConfig(TextDecoderOnnxConfig): NORMALIZED_CONFIG_CLASS = NormalizedTextConfig -class LlamaDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator): - def __init__( - self, - task: str, - normalized_config: NormalizedTextConfig, - batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"], - sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"], - random_batch_size_range: Optional[Tuple[int, int]] = None, - random_sequence_length_range: Optional[Tuple[int, int]] = None, - **kwargs, - ): - super().__init__( - task=task, - normalized_config=normalized_config, - batch_size=batch_size, - sequence_length=sequence_length, - random_batch_size_range=random_batch_size_range, - random_sequence_length_range=random_sequence_length_range, - **kwargs, - ) - self.num_key_value_heads = normalized_config.num_key_value_heads - - def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): - shape = ( - self.batch_size, - self.num_key_value_heads, - self.sequence_length, - self.hidden_size // self.num_attention_heads, - ) - return [ - ( - self.random_float_tensor(shape, 
framework=framework, dtype=float_dtype), - self.random_float_tensor(shape, framework=framework, dtype=float_dtype), - ) - for _ in range(self.num_layers) - ] - - class LlamaOnnxConfig(TextDecoderWithPositionIdsOnnxConfig): - DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, LlamaDummyPastKeyValuesGenerator) - DUMMY_PKV_GENERATOR_CLASS = LlamaDummyPastKeyValuesGenerator - + DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, MistralDummyPastKeyValuesGenerator) + DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator DEFAULT_ONNX_OPSET = 13 NORMALIZED_CONFIG_CLASS = NormalizedTextConfig diff --git a/optimum/onnxruntime/base.py b/optimum/onnxruntime/base.py index 419d3417c30..32377fb6060 100644 --- a/optimum/onnxruntime/base.py +++ b/optimum/onnxruntime/base.py @@ -217,7 +217,7 @@ def prepare_inputs_for_merged( if self.parent_model.use_merged and past_key_values is None: batch_size = input_ids.shape[0] - if self.normalized_config.config.model_type != "mistral": + if self.normalized_config.config.model_type not in {"mistral", "llama"}: num_attention_heads = self.normalized_config.num_attention_heads else: num_attention_heads = self.normalized_config.num_key_value_heads @@ -281,7 +281,7 @@ def compute_past_key_values_output_shapes( `Dict[str, List[int]]`: The dictionary mapping each past key value output name to its corresponding shape.
""" batch_size = input_ids.size(0) - if self.normalized_config.config.model_type != "mistral": + if self.normalized_config.config.model_type in {"mistral", "llama"}: num_attention_heads = self.normalized_config.num_attention_heads else: num_attention_heads = self.normalized_config.num_key_value_heads diff --git a/optimum/utils/input_generators.py b/optimum/utils/input_generators.py index 1797028f33d..c444c913cac 100644 --- a/optimum/utils/input_generators.py +++ b/optimum/utils/input_generators.py @@ -922,8 +922,8 @@ def __init__( **kwargs, ): super().__init__( - task, - normalized_config, + task=task, + normalized_config=normalized_config, batch_size=batch_size, sequence_length=sequence_length, random_batch_size_range=random_batch_size_range,