Skip to content

Commit

Permalink
update for recent transformers updates
Browse files Browse the repository at this point in the history
  • Loading branch information
winglian committed Sep 26, 2023
1 parent 5e5296a commit 93a277f
Showing 1 changed file with 17 additions and 3 deletions.
20 changes: 17 additions & 3 deletions src/axolotl/monkeypatch/llama_attn_hijack_flash.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def flashattn_forward(
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
padding_mask: Optional[torch.LongTensor] = None, # pylint: disable=unused-argument
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
Expand Down Expand Up @@ -476,6 +477,13 @@ def llama_model_forward(
dtype=torch.bool,
device=inputs_embeds.device,
)
padding_mask = None
else:
if 0 in attention_mask:
padding_mask = attention_mask
else:
padding_mask = None

attention_mask = (
self._prepare_decoder_attention_mask( # pylint: disable=protected-access
attention_mask,
Expand Down Expand Up @@ -510,7 +518,12 @@ def llama_model_forward(
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs)
return module(
*inputs,
past_key_value, # pylint: disable=(cell-var-from-loop)
output_attentions,
attention_mask=attention_mask,
)

return custom_forward

Expand All @@ -520,8 +533,6 @@ def custom_forward(*inputs):
attention_mask,
position_ids,
None,
output_attentions,
None,
cu_seqlens,
max_seqlen,
)
Expand All @@ -533,6 +544,7 @@ def custom_forward(*inputs):
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
padding_mask=padding_mask,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
Expand Down Expand Up @@ -579,6 +591,7 @@ def forward(
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
padding_mask: Optional[torch.LongTensor] = None,
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[torch.Tensor] = None,
) -> Tuple[
Expand Down Expand Up @@ -611,6 +624,7 @@ def forward(
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
padding_mask=padding_mask,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
Expand Down

0 comments on commit 93a277f

Please sign in to comment.