Skip to content

Commit

Permalink
fix dynamic axis for position ids
Browse files Browse the repository at this point in the history
  • Loading branch information
echarlaix committed Sep 19, 2023
1 parent 4b68caa commit e5fd9f8
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 5 deletions.
5 changes: 1 addition & 4 deletions optimum/exporters/onnx/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,10 +164,7 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
# generating wrong position_ids in the model itself:
# https://github.com/huggingface/transformers/blob/v4.33.1/src/transformers/models/gpt2/modeling_gpt2.py#L802
if not self.no_position_ids and self.task == "text-generation":
if self.use_past_in_inputs:
common_inputs["position_ids"] = {0: "batch_size"}
else:
common_inputs["position_ids"] = {0: "batch_size", 1: "sequence_length"}
common_inputs["position_ids"] = {0: "batch_size", 1: "sequence_length"}

return common_inputs

Expand Down
2 changes: 1 addition & 1 deletion optimum/onnxruntime/modeling_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,8 +699,8 @@ def forward(
self,
input_ids: torch.LongTensor,
attention_mask: Optional[torch.FloatTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
labels: Optional[torch.LongTensor] = None,
use_cache_branch: None = None,
**kwargs,
Expand Down

0 comments on commit e5fd9f8

Please sign in to comment.