diff --git a/README.md b/README.md
index 3b3bc9985a..7f3230423c 100644
--- a/README.md
+++ b/README.md
@@ -636,6 +636,8 @@ flash_optimum:
 xformers_attention:
 # whether to use flash attention patch https://github.com/Dao-AILab/flash-attention:
 flash_attention:
+flash_attn_cross_entropy: # Whether to use flash-attention cross entropy implementation - advanced use only
+flash_attn_rms_norm: # Whether to use flash-attention rms norm implementation - advanced use only
 # whether to use scaled-dot-product attention
 # https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
 sdp_attention:
diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
index 97f0477649..4f6b715756 100644
--- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
+++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
@@ -38,7 +38,11 @@
 LOG = logging.getLogger("axolotl")
 
 
-def replace_llama_attn_with_flash_attn(packed: Optional[bool] = False):
+def replace_llama_attn_with_flash_attn(
+    packed: Optional[bool] = False,
+    cross_entropy: Optional[bool] = False,
+    rms_norm: Optional[bool] = False,
+):
     transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (  # pylint: disable=protected-access
         _prepare_decoder_attention_mask
     )
@@ -49,33 +53,37 @@ def replace_llama_attn_with_flash_attn(packed: Optional[bool] = False):
             llama_model_forward
         )
 
-    try:
-        from flash_attn.losses.cross_entropy import CrossEntropyLoss
+    # skip only if explicitly disabled
+    if cross_entropy:
+        try:
+            from flash_attn.losses.cross_entropy import CrossEntropyLoss
 
-        LOG.info("patching with flash_attn.losses.cross_entropy")
-        transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
-            CrossEntropyLoss, inplace_backward=True
-        )
-    except ImportError:
-        LOG.info(
-            "optimized flash-attention CrossEntropyLoss not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=xentropy_cuda_lib&subdirectory=csrc/xentropy'`)"
-        )
+            LOG.info("patching with flash_attn.losses.cross_entropy")
+            transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
+                CrossEntropyLoss, inplace_backward=True
+            )
+        except ImportError:
+            LOG.info(
+                "optimized flash-attention CrossEntropyLoss not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=xentropy_cuda_lib&subdirectory=csrc/xentropy'`)"
+            )
 
-    try:
-        from flash_attn.ops.rms_norm import RMSNorm
+    # skip only if explicitly disabled
+    if rms_norm:
+        try:
+            from flash_attn.ops.rms_norm import RMSNorm
 
-        class LlamaRMSNorm(RMSNorm):
-            """Patched LLamaRMSNorm"""
+            class LlamaRMSNorm(RMSNorm):
+                """Patched LLamaRMSNorm"""
 
-            def __init__(self, hidden_size, eps=1e-6):
-                super().__init__(hidden_size, eps=eps)
+                def __init__(self, hidden_size, eps=1e-6):
+                    super().__init__(hidden_size, eps=eps)
 
-        LOG.info("patching with flash_attn.ops.rms_norm")
-        transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
-    except ImportError:
-        LOG.info(
-            "optimized flash-attention RMSNorm not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=dropout_layer_norm&subdirectory=csrc/layer_norm'`)"
-        )
+            LOG.info("patching with flash_attn.ops.rms_norm")
+            transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
+        except ImportError:
+            LOG.info(
+                "optimized flash-attention RMSNorm not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=dropout_layer_norm&subdirectory=csrc/layer_norm'`)"
+            )
 
 
 # Disable the transformation of the attention mask in LlamaModel as the flash attention
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 361440931f..07cdc4d6ed 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -121,7 +121,11 @@ def load_model(
             )
 
             LOG.info("patching with flash attention for sample packing")
-            replace_llama_attn_with_flash_attn(packed=cfg.sample_packing)
+            replace_llama_attn_with_flash_attn(
+                packed=cfg.sample_packing,
+                cross_entropy=cfg.flash_attn_cross_entropy,
+                rms_norm=cfg.flash_attn_rms_norm,
+            )
     elif cfg.is_llama_derived_model and cfg.xformers_attention:
         from axolotl.monkeypatch.llama_attn_hijack_xformers import (
             hijack_llama_attention,
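Example usage (illustrative sketch, not part of the patch): with this change the optimized kernels are only monkeypatched in when their config flags are truthy, so a run has to opt in explicitly. A minimal axolotl YAML snippet, assuming both new keys are left unset (disabled) by default:

    flash_attention: true
    flash_attn_cross_entropy: true  # needs the flash-attn xentropy_cuda_lib extra (see the ImportError hint above)
    flash_attn_rms_norm: true  # needs the flash-attn dropout_layer_norm extra (see the ImportError hint above)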