From 1188299a773cea21d2b06742715f65fed9bd5322 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Wed, 27 Sep 2023 12:10:32 -0400 Subject: [PATCH] update for recent transformers updates (#636) * update for recent transformers updates * fix checkpoint forward kwargs * just pass args into torch checkpoint --- .../monkeypatch/llama_attn_hijack_flash.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py index d172d302d9..97f0477649 100644 --- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py +++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py @@ -99,6 +99,7 @@ def flashattn_forward( past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: bool = False, use_cache: bool = False, + padding_mask: Optional[torch.LongTensor] = None, # pylint: disable=unused-argument cu_seqlens: Optional[torch.Tensor] = None, max_seqlen: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: @@ -476,6 +477,13 @@ def llama_model_forward( dtype=torch.bool, device=inputs_embeds.device, ) + padding_mask = None + else: + if 0 in attention_mask: + padding_mask = attention_mask + else: + padding_mask = None + attention_mask = ( self._prepare_decoder_attention_mask( # pylint: disable=protected-access attention_mask, @@ -510,7 +518,9 @@ def llama_model_forward( def create_custom_forward(module): def custom_forward(*inputs): # None for past_key_value - return module(*inputs) + return module( + *inputs, + ) return custom_forward @@ -519,9 +529,10 @@ def custom_forward(*inputs): hidden_states, attention_mask, position_ids, - None, + past_key_value, output_attentions, None, + padding_mask, cu_seqlens, max_seqlen, ) @@ -533,6 +544,7 @@ def custom_forward(*inputs): past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, + padding_mask=padding_mask, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, ) @@ -579,6 +591,7 @@ def forward( past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, + padding_mask: Optional[torch.LongTensor] = None, cu_seqlens: Optional[torch.Tensor] = None, max_seqlen: Optional[torch.Tensor] = None, ) -> Tuple[ @@ -611,6 +624,7 @@ def forward( past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, + padding_mask=padding_mask, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, )