update for recent transformers updates (#636)
* update for recent transformers updates

* fix checkpoint forward kwargs

* just pass args into torch checkpoint
winglian authored Sep 27, 2023
1 parent e8cbf50 commit 60c7c48
Showing 1 changed file with 16 additions and 2 deletions.
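The third commit bullet ("just pass args into torch checkpoint") shows up in the gradient-checkpointing hunks below: rather than closing over extra keyword arguments inside create_custom_forward, every layer input is handed to torch.utils.checkpoint.checkpoint positionally and forwarded unchanged to the wrapped module. The following is a minimal, self-contained sketch of that pattern only; TinyLayer and its tensors are made-up stand-ins, not the axolotl code.

import torch
import torch.utils.checkpoint


class TinyLayer(torch.nn.Module):
    # Illustrative stand-in for a decoder layer; not the axolotl implementation.
    def forward(self, hidden_states, attention_mask=None, padding_mask=None):
        # The layer receives exactly what the checkpoint wrapper forwarded.
        if attention_mask is not None:
            hidden_states = hidden_states + attention_mask
        return hidden_states


def create_custom_forward(module):
    def custom_forward(*inputs):
        # Positional inputs arrive exactly as they were given to checkpoint().
        return module(
            *inputs,
        )

    return custom_forward


layer = TinyLayer()
hidden_states = torch.randn(2, 4, requires_grad=True)
attention_mask = torch.zeros(2, 4)
padding_mask = None  # None is a legal positional placeholder

layer_outputs = torch.utils.checkpoint.checkpoint(
    create_custom_forward(layer),
    hidden_states,
    attention_mask,
    padding_mask,
)
layer_outputs.sum().backward()  # the forward pass is recomputed during backward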
18 changes: 16 additions & 2 deletions src/axolotl/monkeypatch/llama_attn_hijack_flash.py
@@ -99,6 +99,7 @@ def flashattn_forward(
     past_key_value: Optional[Tuple[torch.Tensor]] = None,
     output_attentions: bool = False,
     use_cache: bool = False,
+    padding_mask: Optional[torch.LongTensor] = None,  # pylint: disable=unused-argument
     cu_seqlens: Optional[torch.Tensor] = None,
     max_seqlen: Optional[torch.Tensor] = None,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
@@ -476,6 +477,13 @@ def llama_model_forward(
             dtype=torch.bool,
             device=inputs_embeds.device,
         )
+        padding_mask = None
+    else:
+        if 0 in attention_mask:
+            padding_mask = attention_mask
+        else:
+            padding_mask = None
+
     attention_mask = (
         self._prepare_decoder_attention_mask(  # pylint: disable=protected-access
             attention_mask,
@@ -510,7 +518,9 @@
             def create_custom_forward(module):
                 def custom_forward(*inputs):
                     # None for past_key_value
-                    return module(*inputs)
+                    return module(
+                        *inputs,
+                    )

                 return custom_forward

@@ -519,9 +529,10 @@ def custom_forward(*inputs):
                 hidden_states,
                 attention_mask,
                 position_ids,
-                None,
+                past_key_value,
                 output_attentions,
                 None,
+                padding_mask,
                 cu_seqlens,
                 max_seqlen,
             )
@@ -533,6 +544,7 @@ def custom_forward(*inputs):
                 past_key_value=past_key_value,
                 output_attentions=output_attentions,
                 use_cache=use_cache,
+                padding_mask=padding_mask,
                 cu_seqlens=cu_seqlens,
                 max_seqlen=max_seqlen,
             )
@@ -579,6 +591,7 @@ def forward(
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
+        padding_mask: Optional[torch.LongTensor] = None,
         cu_seqlens: Optional[torch.Tensor] = None,
         max_seqlen: Optional[torch.Tensor] = None,
     ) -> Tuple[
@@ -611,6 +624,7 @@ def forward(
             past_key_value=past_key_value,
             output_attentions=output_attentions,
             use_cache=use_cache,
+            padding_mask=padding_mask,
             cu_seqlens=cu_seqlens,
             max_seqlen=max_seqlen,
         )
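For context on the new padding_mask branch added to llama_model_forward above: a padding mask is only materialized when the 2-D attention mask actually contains padded positions (at least one zero entry); otherwise the kwarg stays None, so downstream layers can cheaply tell whether any sequence in the batch is padded. A minimal sketch with made-up tensors, not the axolotl code:

import torch

# Made-up 2-D attention mask: 1 = real token, 0 = padding.
attention_mask = torch.tensor(
    [
        [1, 1, 1, 0],  # last position is padding
        [1, 1, 1, 1],  # fully populated sequence
    ]
)

# Mirrors the branch added in the diff: only keep a padding mask when
# at least one position is actually padded.
if 0 in attention_mask:
    padding_mask = attention_mask
else:
    padding_mask = None

print(padding_mask is not None)  # True here, because the first row contains a 0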
