Skip to content

Commit

Permalink
Updated decoder_sequence_name to "past_sequence_length + sequence_length"
Browse files Browse the repository at this point in the history

instead of "past_sequence_length + 1", since the output length depends on
the input sequence_length, which may not be 1 (e.g. when filling the kv-cache).
  • Loading branch information
PatrikPerssonInceptron committed Nov 4, 2024
1 parent 7e8d857 commit 0badc67
Show file tree
Hide file tree
Showing 5 changed files with 8 additions and 8 deletions.
2 changes: 1 addition & 1 deletion docs/source/exporters/onnx/usage_guides/export_a_model.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ class CustomMPTOnnxConfig(TextDecoderOnnxConfig):
decoder_sequence_name = "past_sequence_length"
name = "past_key_values"
else:
decoder_sequence_name = "past_sequence_length + 1"
decoder_sequence_name = "past_sequence_length + sequence_length"
name = "present"

for i in range(self._normalized_config.num_layers):
Expand Down
2 changes: 1 addition & 1 deletion optimum/exporters/onnx/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -717,7 +717,7 @@ def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], dire
decoder_sequence_name = "past_sequence_length"
name = "past_key_values"
else:
decoder_sequence_name = "past_sequence_length + 1"
decoder_sequence_name = "past_sequence_length + sequence_length"
name = "present"

for i in range(self._normalized_config.num_layers):
Expand Down
2 changes: 1 addition & 1 deletion optimum/exporters/onnx/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
if self.use_past_in_inputs:
common_inputs = {"input_ids": {0: "batch_size", 1: "sequence_length"}}
self.add_past_key_values(common_inputs, direction="inputs")
common_inputs["attention_mask"] = {0: "batch_size", 1: "past_sequence_length + 1"}
common_inputs["attention_mask"] = {0: "batch_size", 1: "past_sequence_length + sequence_length"}
else:
common_inputs = {
"input_ids": {0: "batch_size", 1: "sequence_length"},
Expand Down
8 changes: 4 additions & 4 deletions optimum/exporters/onnx/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], dire
decoder_sequence_name = "past_sequence_length"
name = "past_key_values"
else:
decoder_sequence_name = "past_sequence_length + 1"
decoder_sequence_name = "past_sequence_length + sequence_length"
name = "present"

for i in range(self._normalized_config.num_layers):
Expand Down Expand Up @@ -403,7 +403,7 @@ def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], dire
decoder_sequence_name = "past_sequence_length"
name = "past_key_values"
else:
decoder_sequence_name = "past_sequence_length + 1"
decoder_sequence_name = "past_sequence_length + sequence_length"
name = "present"

for i in range(self._normalized_config.num_layers):
Expand Down Expand Up @@ -638,7 +638,7 @@ def inputs_for_causal_lm(self):
if self.use_past_in_inputs:
common_inputs = {
"input_ids": {0: "batch_size", 1: "sequence_length"},
"attention_mask": {0: "batch_size", 1: "past_sequence_length + 1"},
"attention_mask": {0: "batch_size", 1: "past_sequence_length + sequence_length"},
}
for i in range(self._normalized_config.decoder_num_layers):
common_inputs[f"past_key_values.{i}.key"] = {
Expand Down Expand Up @@ -2216,7 +2216,7 @@ def inputs(self):
common_inputs["encoder_outputs"] = {0: "batch_size"}

# Contrary to other seq2seq archs as t5 and bart, Pix2Struct DO make use of the decoder_attention_mask input.
common_inputs["decoder_attention_mask"] = {0: "batch_size", 1: "past_sequence_length + 1"}
common_inputs["decoder_attention_mask"] = {0: "batch_size", 1: "past_sequence_length + sequence_length"}

return common_inputs

Expand Down
2 changes: 1 addition & 1 deletion tests/exporters/onnx/test_onnx_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,7 @@ def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], dire
decoder_sequence_name = "past_sequence_length"
name = "past_key_values"
else:
decoder_sequence_name = "past_sequence_length + 1"
decoder_sequence_name = "past_sequence_length + sequence_length"
name = "present"

for i in range(self._normalized_config.num_layers):
Expand Down

0 comments on commit 0badc67

Please sign in to comment.