diff --git a/paddlenlp/transformers/llama/modeling.py b/paddlenlp/transformers/llama/modeling.py
index 77b6e2422cef..e48411217717 100644
--- a/paddlenlp/transformers/llama/modeling.py
+++ b/paddlenlp/transformers/llama/modeling.py
@@ -364,8 +364,9 @@ def _set_cos_sin_cache(self, seq_len):
         self.sin_cached = emb.sin()[None, :, None, :]
 
     def forward(self, x, seq_len=None):
-        cos = self.cos_cached[:, :seq_len, :, ...]
-        sin = self.sin_cached[:, :seq_len, :, ...]
+        # x: [bs, num_attention_heads, seq_len, head_size]
+        cos = self.cos_cached[:, :, :seq_len, ...]
+        sin = self.sin_cached[:, :, :seq_len, ...]
         return (
             cos.cast(x.dtype) if cos.dtype != x.dtype else cos,
             sin.cast(x.dtype) if sin.dtype != x.dtype else sin,
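
For the new `[:, :, :seq_len, ...]` slice to select positions, the cos/sin caches need the sequence dimension on axis 2, i.e. a `[1, 1, max_seq_len, head_dim]` layout matching the added `[bs, num_attention_heads, seq_len, head_size]` comment. Note the unchanged context line still builds the cache as `emb.sin()[None, :, None, :]` ( `[1, max_seq_len, 1, head_dim]` ), so `_set_cos_sin_cache` presumably receives the corresponding layout change elsewhere in the PR. Below is a minimal, self-contained sketch under that assumption; the class name `RotaryEmbeddingSketch` and its constructor arguments are illustrative, not the exact PaddleNLP API:

```python
import paddle


class RotaryEmbeddingSketch(paddle.nn.Layer):
    """Hypothetical, simplified rotary embedding. Assumes a
    [1, 1, max_seq_len, head_dim] cache so axis 2 is the position axis,
    consistent with the slicing introduced in the diff above."""

    def __init__(self, head_dim, max_position_embeddings=2048, base=10000):
        super().__init__()
        # inv_freq: [head_dim // 2], the standard RoPE inverse frequencies
        self.inv_freq = 1.0 / (
            base ** (paddle.arange(0, head_dim, 2, dtype="float32") / head_dim)
        )
        self._set_cos_sin_cache(max_position_embeddings)

    def _set_cos_sin_cache(self, seq_len):
        self.max_seq_len_cached = seq_len
        t = paddle.arange(seq_len, dtype="float32")          # [seq_len]
        freqs = paddle.einsum("i,j->ij", t, self.inv_freq)   # [seq_len, head_dim // 2]
        emb = paddle.concat([freqs, freqs], axis=-1)         # [seq_len, head_dim]
        # [1, 1, seq_len, head_dim]: position on axis 2, so the slice in
        # forward() below selects the first seq_len positions.
        self.cos_cached = emb.cos()[None, None, :, :]
        self.sin_cached = emb.sin()[None, None, :, :]

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len is not None and seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len)  # grow the cache on demand
        cos = self.cos_cached[:, :, :seq_len, ...]
        sin = self.sin_cached[:, :, :seq_len, ...]
        return (
            cos.cast(x.dtype) if cos.dtype != x.dtype else cos,
            sin.cast(x.dtype) if sin.dtype != x.dtype else sin,
        )


# Usage: slice the precomputed tables to the current sequence length.
rope = RotaryEmbeddingSketch(head_dim=128)
x = paddle.randn([2, 32, 16, 128])  # [bs, num_heads, seq_len, head_dim]
cos, sin = rope(x, seq_len=x.shape[2])
print(cos.shape)  # [1, 1, 16, 128]
```

The design mirrors the diff's two ideas: the cos/sin tables are computed once and merely sliced per forward pass, avoiding per-step trig, and the conditional `cast` keeps the cache in float32 for precision while returning tensors in the activations' dtype (e.g. float16 under mixed precision).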