From 93a277fc9bfb494b94fd60144ec5d0a71020c287 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Tue, 26 Sep 2023 08:45:12 -0400
Subject: [PATCH] update for recent transformers updates

---
 .../monkeypatch/llama_attn_hijack_flash.py    | 20 ++++++++++++++++---
 1 file changed, 17 insertions(+), 3 deletions(-)

diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
index d172d302d9..db2af54631 100644
--- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
+++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
@@ -99,6 +99,7 @@ def flashattn_forward(
     past_key_value: Optional[Tuple[torch.Tensor]] = None,
     output_attentions: bool = False,
     use_cache: bool = False,
+    padding_mask: Optional[torch.LongTensor] = None,  # pylint: disable=unused-argument
     cu_seqlens: Optional[torch.Tensor] = None,
     max_seqlen: Optional[torch.Tensor] = None,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
@@ -476,6 +477,13 @@ def llama_model_forward(
             dtype=torch.bool,
             device=inputs_embeds.device,
         )
+        padding_mask = None
+    else:
+        if 0 in attention_mask:
+            padding_mask = attention_mask
+        else:
+            padding_mask = None
+
     attention_mask = (
         self._prepare_decoder_attention_mask(  # pylint: disable=protected-access
             attention_mask,
@@ -510,7 +518,12 @@ def llama_model_forward(
             def create_custom_forward(module):
                 def custom_forward(*inputs):
                     # None for past_key_value
-                    return module(*inputs)
+                    return module(
+                        *inputs,
+                        past_key_value,  # pylint: disable=(cell-var-from-loop)
+                        output_attentions,
+                        attention_mask=attention_mask,
+                    )
 
                 return custom_forward
 
@@ -520,8 +533,6 @@ def custom_forward(*inputs):
                 attention_mask,
                 position_ids,
                 None,
-                output_attentions,
-                None,
                 cu_seqlens,
                 max_seqlen,
             )
@@ -533,6 +544,7 @@ def custom_forward(*inputs):
                 past_key_value=past_key_value,
                 output_attentions=output_attentions,
                 use_cache=use_cache,
+                padding_mask=padding_mask,
                 cu_seqlens=cu_seqlens,
                 max_seqlen=max_seqlen,
             )
@@ -579,6 +591,7 @@ def forward(
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
+        padding_mask: Optional[torch.LongTensor] = None,
         cu_seqlens: Optional[torch.Tensor] = None,
         max_seqlen: Optional[torch.Tensor] = None,
     ) -> Tuple[
@@ -611,6 +624,7 @@ def forward(
             past_key_value=past_key_value,
             output_attentions=output_attentions,
             use_cache=use_cache,
+            padding_mask=padding_mask,
             cu_seqlens=cu_seqlens,
             max_seqlen=max_seqlen,
         )
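
Not part of the patch: a minimal standalone sketch of the padding_mask derivation the @@ -476,6 +477,13 @@ hunk adds to llama_model_forward. The helper name derive_padding_mask is hypothetical and exists only for illustration; the logic mirrors the diff above, assuming attention_mask is either None or a tensor.

# Hypothetical illustration of the added branch; not code from the patch.
import torch


def derive_padding_mask(attention_mask, batch_size, seq_len, device):
    # When no mask is supplied, build an all-ones mask and report no padding.
    if attention_mask is None:
        attention_mask = torch.ones(
            (batch_size, seq_len), dtype=torch.bool, device=device
        )
        padding_mask = None
    else:
        # Only keep a padding_mask when the mask actually contains padding (a 0 entry).
        padding_mask = attention_mask if 0 in attention_mask else None
    return attention_mask, padding_mask


Example: derive_padding_mask(None, 2, 16, torch.device("cpu")) returns an all-True mask with padding_mask=None, while a mask containing zeros is passed through as padding_mask so downstream decoder layers receive the new keyword.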