Commit 61d89c1
Fix: Mamba2 generation mismatch between input_ids and inputs_embeds (huggingface#32694)

* fix cache when using input embeddings

* simplify check: we can always add the input_ids seq len, since it's 0 on the first pass
vasqu authored Aug 19, 2024
1 parent 93e538a commit 61d89c1
Showing 1 changed file with 2 additions and 2 deletions.
src/transformers/models/mamba2/modeling_mamba2.py
@@ -964,8 +964,8 @@ def prepare_inputs_for_generation(
         attention_mask: Optional[torch.Tensor] = None,
         **kwargs,
     ):
-        if input_ids.shape[1] == 0:
-            past_len = inputs_embeds.shape[1]
+        if inputs_embeds is not None:
+            past_len = inputs_embeds.shape[1] + input_ids.shape[1]
         else:
             past_len = input_ids.shape[1]
         if use_cache:
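
Why the simplified check is safe: on the first generation pass, when the prompt is supplied as inputs_embeds, input_ids has sequence length 0, so adding it changes nothing; on subsequent decoding steps inputs_embeds is None and past_len falls back to the input_ids length. A minimal sketch of that behavior (hypothetical tensor shapes, not part of the commit):

import torch

# First pass: the prompt is provided as embeddings; input_ids is empty.
inputs_embeds = torch.randn(1, 5, 768)            # (batch, seq_len, hidden_size)
input_ids = torch.empty(1, 0, dtype=torch.long)   # seq len is 0 on the first pass

if inputs_embeds is not None:
    past_len = inputs_embeds.shape[1] + input_ids.shape[1]  # 5 + 0 = 5
else:
    past_len = input_ids.shape[1]
assert past_len == 5

# Later steps: only newly generated token ids are passed, no embeddings.
inputs_embeds = None
input_ids = torch.tensor([[42]])                  # one new token id

if inputs_embeds is not None:
    past_len = inputs_embeds.shape[1] + input_ids.shape[1]
else:
    past_len = input_ids.shape[1]
assert past_len == 1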
