use flash_attn xentropy when available
tmm1 committed Sep 4, 2023
1 parent 09f1543 commit e37d3b2
Showing 1 changed file with 14 additions and 0 deletions.
src/axolotl/monkeypatch/llama_attn_hijack_flash.py (14 additions, 0 deletions)
@@ -2,7 +2,9 @@
 
 # copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py
 
+import logging
 import warnings
+from functools import partial
 from typing import List, Optional, Tuple, Union
 
 import torch
@@ -33,6 +35,9 @@
 )
 
 
+LOG = logging.getLogger("axolotl")
+
+
 def replace_llama_attn_with_flash_attn(packed: Optional[bool] = False):
     transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (  # pylint: disable=protected-access
         _prepare_decoder_attention_mask
@@ -43,6 +48,15 @@ def replace_llama_attn_with_flash_attn(packed: Optional[bool] = False):
         transformers.models.llama.modeling_llama.LlamaModel.forward = (
             llama_model_forward
         )
+    try:
+        from flash_attn.losses.cross_entropy import CrossEntropyLoss
+
+        LOG.info("patching with flash_attention.losses.cross_entropy")
+        transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
+            CrossEntropyLoss, inplace_backward=True
+        )
+    except ImportError:
+        pass
 
 
 # Disable the transformation of the attention mask in LlamaModel as the flash attention
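Below is a minimal usage sketch, not part of the commit, showing how the monkeypatch is intended to be applied before model construction and how a caller might check whether the fused loss was swapped in. The isinstance check and the print statements are illustrative assumptions; the commit itself only replaces the module-level CrossEntropyLoss symbol and silently falls back when flash_attn is not importable.

# Minimal sketch (not from the commit): apply the monkeypatch before the model
# is built so modeling_llama resolves the (possibly) patched CrossEntropyLoss.
import functools

import transformers
from axolotl.monkeypatch.llama_attn_hijack_flash import (
    replace_llama_attn_with_flash_attn,
)

replace_llama_attn_with_flash_attn()  # packed defaults to False

patched = transformers.models.llama.modeling_llama.CrossEntropyLoss
if isinstance(patched, functools.partial):
    # flash_attn's xentropy kernel was importable: the symbol now wraps
    # flash_attn.losses.cross_entropy.CrossEntropyLoss with inplace_backward=True,
    # which computes the logits gradient in place to save memory.
    print("using flash_attn fused cross entropy:", patched.func)
else:
    # flash_attn (or its xentropy extension) is missing: nothing was patched,
    # and the loss is still torch.nn.CrossEntropyLoss.
    print("using torch.nn.CrossEntropyLoss")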
