Skip to content

Commit

Permalink
fix after #1432 merged
Browse files Browse the repository at this point in the history
  • Loading branch information
echarlaix committed Oct 6, 2023
1 parent aca22df commit c3b584a
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 45 deletions.
44 changes: 3 additions & 41 deletions optimum/exporters/onnx/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
NormalizedVisionConfig,
logging,
)

from ...utils.normalized_config import NormalizedConfigManager
from .base import ConfigBehavior, OnnxConfig, OnnxConfigWithPast, OnnxSeq2SeqConfigWithPast
from .config import (
Expand Down Expand Up @@ -217,48 +218,9 @@ class OPTOnnxConfig(TextDecoderOnnxConfig):
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig


class LlamaDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator):
def __init__(
self,
task: str,
normalized_config: NormalizedTextConfig,
batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"],
random_batch_size_range: Optional[Tuple[int, int]] = None,
random_sequence_length_range: Optional[Tuple[int, int]] = None,
**kwargs,
):
super().__init__(
task=task,
normalized_config=normalized_config,
batch_size=batch_size,
sequence_length=sequence_length,
random_batch_size_range=random_batch_size_range,
random_sequence_length_range=random_sequence_length_range,
**kwargs,
)
self.num_key_value_heads = normalized_config.num_key_value_heads

def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
shape = (
self.batch_size,
self.num_key_value_heads,
self.sequence_length,
self.hidden_size // self.num_attention_heads,
)
return [
(
self.random_float_tensor(shape, framework=framework, dtype=float_dtype),
self.random_float_tensor(shape, framework=framework, dtype=float_dtype),
)
for _ in range(self.num_layers)
]


class LlamaOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, LlamaDummyPastKeyValuesGenerator)
DUMMY_PKV_GENERATOR_CLASS = LlamaDummyPastKeyValuesGenerator

DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, MistralDummyPastKeyValuesGenerator)
DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator
DEFAULT_ONNX_OPSET = 13
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig

Expand Down
4 changes: 2 additions & 2 deletions optimum/onnxruntime/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ def prepare_inputs_for_merged(
if self.parent_model.use_merged and past_key_values is None:
batch_size = input_ids.shape[0]

if self.normalized_config.config.model_type != "mistral":
if self.normalized_config.config.model_type in {"mistral", "llama"}:
num_attention_heads = self.normalized_config.num_attention_heads
else:
num_attention_heads = self.normalized_config.num_key_value_heads
Expand Down Expand Up @@ -281,7 +281,7 @@ def compute_past_key_values_output_shapes(
`Dict[str, List[int]]`: The dictionary mapping each past key value output name to its corresponding shape.
"""
batch_size = input_ids.size(0)
if self.normalized_config.config.model_type != "mistral":
if self.normalized_config.config.model_type in {"mistral", "llama"}:
num_attention_heads = self.normalized_config.num_attention_heads
else:
num_attention_heads = self.normalized_config.num_key_value_heads
Expand Down
4 changes: 2 additions & 2 deletions optimum/utils/input_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -922,8 +922,8 @@ def __init__(
**kwargs,
):
super().__init__(
task,
normalized_config,
task=task,
normalized_config=normalized_config,
batch_size=batch_size,
sequence_length=sequence_length,
random_batch_size_range=random_batch_size_range,
Expand Down

0 comments on commit c3b584a

Please sign in to comment.