From cdeba0711b911b8cc81f45199bd2265b9eeb0f43 Mon Sep 17 00:00:00 2001 From: Aman Karmani Date: Mon, 4 Sep 2023 04:04:34 +0000 Subject: [PATCH 1/2] use flash_attn xentropy when available --- .../monkeypatch/llama_attn_hijack_flash.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py index cb0aa3fe6f..33de909719 100644 --- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py +++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py @@ -2,7 +2,9 @@ # copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py +import logging import warnings +from functools import partial from typing import List, Optional, Tuple, Union import torch @@ -33,6 +35,9 @@ ) +LOG = logging.getLogger("axolotl") + + def replace_llama_attn_with_flash_attn(packed: Optional[bool] = False): transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( # pylint: disable=protected-access _prepare_decoder_attention_mask @@ -44,6 +49,16 @@ def replace_llama_attn_with_flash_attn(packed: Optional[bool] = False): llama_model_forward ) + try: + from flash_attn.losses.cross_entropy import CrossEntropyLoss + + LOG.info("patching with flash_attn.losses.cross_entropy") + transformers.models.llama.modeling_llama.CrossEntropyLoss = partial( + CrossEntropyLoss, inplace_backward=True + ) + except ImportError: + pass + # Disable the transformation of the attention mask in LlamaModel as the flash attention # requires the attention mask to be the same as the key_padding_mask From abd4f9a10ed31d1664405318b7256bb54b6fdb22 Mon Sep 17 00:00:00 2001 From: Aman Gupta Karmani Date: Mon, 4 Sep 2023 17:43:52 -0400 Subject: [PATCH 2/2] log when xentropy is not found --- src/axolotl/monkeypatch/llama_attn_hijack_flash.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py index 33de909719..b0163a6556 100644 --- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py +++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py @@ -57,7 +57,9 @@ def replace_llama_attn_with_flash_attn(packed: Optional[bool] = False): CrossEntropyLoss, inplace_backward=True ) except ImportError: - pass + LOG.info( + "optimized flash-attention CrossEntropyLoss not found (run `pip install git+https://github.com/Dao-AILab/flash-attention.git#egg=xentropy_cuda_lib&subdirectory=csrc/xentropy`)" + ) # Disable the transformation of the attention mask in LlamaModel as the flash attention