diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py index cb0aa3fe6f..33de909719 100644 --- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py +++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py @@ -2,7 +2,9 @@ # copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py +import logging import warnings +from functools import partial from typing import List, Optional, Tuple, Union import torch @@ -33,6 +35,9 @@ ) +LOG = logging.getLogger("axolotl") + + def replace_llama_attn_with_flash_attn(packed: Optional[bool] = False): transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( # pylint: disable=protected-access _prepare_decoder_attention_mask @@ -44,6 +49,16 @@ def replace_llama_attn_with_flash_attn(packed: Optional[bool] = False): llama_model_forward ) + try: + from flash_attn.losses.cross_entropy import CrossEntropyLoss + + LOG.info("patching with flash_attn.losses.cross_entropy") + transformers.models.llama.modeling_llama.CrossEntropyLoss = partial( + CrossEntropyLoss, inplace_backward=True + ) + except ImportError: + pass + # Disable the transformation of the attention mask in LlamaModel as the flash attention # requires the attention mask to be the same as the key_padding_mask