Skip to content

Commit

Permalink
skip some flash attn patches unless explicitly enabled (#643)
Browse files Browse the repository at this point in the history
* skip some flash attn patches if explicitly disabled

* make the other patches optional
  • Loading branch information
winglian authored Sep 27, 2023
1 parent e7d3e2d commit 895f0a0
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 24 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -636,6 +636,8 @@ flash_optimum:
xformers_attention:
# whether to use flash attention patch https://github.com/Dao-AILab/flash-attention:
flash_attention:
flash_attn_cross_entropy: # Whether to use flash-attention cross entropy implementation - advanced use only
flash_attn_rms_norm: # Whether to use flash-attention rms norm implementation - advanced use only
# whether to use scaled-dot-product attention
# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
sdp_attention:
Expand Down
54 changes: 31 additions & 23 deletions src/axolotl/monkeypatch/llama_attn_hijack_flash.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,11 @@
LOG = logging.getLogger("axolotl")


def replace_llama_attn_with_flash_attn(packed: Optional[bool] = False):
def replace_llama_attn_with_flash_attn(
packed: Optional[bool] = False,
cross_entropy: Optional[bool] = False,
rms_norm: Optional[bool] = False,
):
transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( # pylint: disable=protected-access
_prepare_decoder_attention_mask
)
Expand All @@ -49,33 +53,37 @@ def replace_llama_attn_with_flash_attn(packed: Optional[bool] = False):
llama_model_forward
)

try:
from flash_attn.losses.cross_entropy import CrossEntropyLoss
# skip only if explicitly disabled
if cross_entropy:
try:
from flash_attn.losses.cross_entropy import CrossEntropyLoss

LOG.info("patching with flash_attn.losses.cross_entropy")
transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
CrossEntropyLoss, inplace_backward=True
)
except ImportError:
LOG.info(
"optimized flash-attention CrossEntropyLoss not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=xentropy_cuda_lib&subdirectory=csrc/xentropy'`)"
)
LOG.info("patching with flash_attn.losses.cross_entropy")
transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
CrossEntropyLoss, inplace_backward=True
)
except ImportError:
LOG.info(
"optimized flash-attention CrossEntropyLoss not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=xentropy_cuda_lib&subdirectory=csrc/xentropy'`)"
)

try:
from flash_attn.ops.rms_norm import RMSNorm
# skip only if explicitly disabled
if rms_norm:
try:
from flash_attn.ops.rms_norm import RMSNorm

class LlamaRMSNorm(RMSNorm):
"""Patched LLamaRMSNorm"""
class LlamaRMSNorm(RMSNorm):
"""Patched LLamaRMSNorm"""

def __init__(self, hidden_size, eps=1e-6):
super().__init__(hidden_size, eps=eps)
def __init__(self, hidden_size, eps=1e-6):
super().__init__(hidden_size, eps=eps)

LOG.info("patching with flash_attn.ops.rms_norm")
transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
except ImportError:
LOG.info(
"optimized flash-attention RMSNorm not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=dropout_layer_norm&subdirectory=csrc/layer_norm'`)"
)
LOG.info("patching with flash_attn.ops.rms_norm")
transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
except ImportError:
LOG.info(
"optimized flash-attention RMSNorm not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=dropout_layer_norm&subdirectory=csrc/layer_norm'`)"
)


# Disable the transformation of the attention mask in LlamaModel as the flash attention
Expand Down
6 changes: 5 additions & 1 deletion src/axolotl/utils/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,11 @@ def load_model(
)

LOG.info("patching with flash attention for sample packing")
replace_llama_attn_with_flash_attn(packed=cfg.sample_packing)
replace_llama_attn_with_flash_attn(
packed=cfg.sample_packing,
cross_entropy=cfg.flash_attn_cross_entropy,
rms_norm=cfg.flash_attn_rms_norm,
)
elif cfg.is_llama_derived_model and cfg.xformers_attention:
from axolotl.monkeypatch.llama_attn_hijack_xformers import (
hijack_llama_attention,
Expand Down

0 comments on commit 895f0a0

Please sign in to comment.