revert llama code
DesmonDay committed Oct 25, 2023
1 parent c6338d8 commit 4c5dadf
Showing 1 changed file with 3 additions and 2 deletions.
paddlenlp/transformers/llama/modeling.py (5 changes: 3 additions & 2 deletions)
@@ -364,8 +364,9 @@ def _set_cos_sin_cache(self, seq_len):
         self.sin_cached = emb.sin()[None, :, None, :]
 
     def forward(self, x, seq_len=None):
-        cos = self.cos_cached[:, :seq_len, :, ...]
-        sin = self.sin_cached[:, :seq_len, :, ...]
+        # x: [bs, num_attention_heads, seq_len, head_size]
+        cos = self.cos_cached[:, :, :seq_len, ...]
+        sin = self.sin_cached[:, :, :seq_len, ...]
         return (
             cos.cast(x.dtype) if cos.dtype != x.dtype else cos,
             sin.cast(x.dtype) if sin.dtype != x.dtype else sin,
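
For context, below is a minimal, self-contained sketch (not the PaddleNLP implementation) of a rotary-embedding cache whose forward pass uses the post-revert slicing `cos_cached[:, :, :seq_len, ...]`. The class name `RotaryEmbeddingSketch`, the constructor parameters, and the cache layout `[1, 1, max_position_embeddings, head_dim]` are illustrative assumptions rather than code from modeling.py; only the slicing and the dtype guard mirror the diff above.

```python
import paddle


class RotaryEmbeddingSketch(paddle.nn.Layer):
    """Illustrative rotary-embedding cache; name and cache layout are assumed."""

    def __init__(self, head_dim, max_position_embeddings=2048, base=10000):
        super().__init__()
        self.head_dim = head_dim
        # Inverse frequencies: 1 / base^(2i / head_dim), one per pair of channels.
        self.inv_freq = 1.0 / (
            base ** (paddle.arange(0, head_dim, 2, dtype="float32") / head_dim)
        )
        self._set_cos_sin_cache(max_position_embeddings)

    def _set_cos_sin_cache(self, seq_len):
        self.max_seq_len_cached = seq_len
        t = paddle.arange(seq_len, dtype="float32")
        # Outer product of positions and inverse frequencies: [seq_len, head_dim // 2].
        freqs = t[:, None] * self.inv_freq[None, :]
        emb = paddle.concat([freqs, freqs], axis=-1)  # [seq_len, head_dim]
        # Assumed cache layout: [1, 1, seq_len, head_dim]; the sequence axis is axis 2.
        self.cos_cached = emb.cos()[None, None, :, :]
        self.sin_cached = emb.sin()[None, None, :, :]

    def forward(self, x, seq_len=None):
        # x is assumed to be [bs, num_attention_heads, seq_len, head_size].
        # Truncate the cached tables to the current sequence length (axis 2).
        cos = self.cos_cached[:, :, :seq_len, ...]
        sin = self.sin_cached[:, :, :seq_len, ...]
        # Cast to the activation dtype only when needed, as in the diff above.
        return (
            cos.cast(x.dtype) if cos.dtype != x.dtype else cos,
            sin.cast(x.dtype) if sin.dtype != x.dtype else sin,
        )
```

Called as, for example, `cos, sin = rope(query_states, seq_len=query_states.shape[2])`, the sketch returns cos/sin tables already matched to the activation dtype, mirroring the conditional cast in the changed file.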
