[GPTNeoX] Flex Attention + Refactor (#34896)
* gpt neox flex attention + refactor

* some formatting

* small fix on dropout

* add assertion on flex attn test

* flaky ci :(

* add head mask support

* style

* handle dtype, replace torch where

* fixup flex with output attns

* code review and several other fixes

* Update src/transformers/modeling_utils.py

Co-authored-by: Arthur <[email protected]>

* style

* remove unnecessary comment

* remove incorrect comment

* make flex attn check more agnostic to versions and centralized

* change peft input dtype check to value since q and k could be affected by other stuff like RoPE

* i forgor

* flaky

* code review and small fixes

* Update src/transformers/models/gpt_neox/modeling_gpt_neox.py

Co-authored-by: Arthur <[email protected]>

---------

Co-authored-by: Arthur <[email protected]>
vasqu and ArthurZucker authored Dec 4, 2024
1 parent accb720 commit 46df859
Showing 6 changed files with 371 additions and 249 deletions.
52 changes: 51 additions & 1 deletion src/transformers/modeling_flash_attention_utils.py
@@ -20,7 +20,10 @@
import torch
import torch.nn.functional as F

from .utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal
from .utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal, logging


logger = logging.get_logger(__name__)


if is_flash_attn_2_available():
@@ -180,6 +183,47 @@ def prepare_fa2_from_position_ids(query, key, value, position_ids):
return (query, key, value, indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length))


def fa_peft_integration_check(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
target_dtype: Optional[torch.dtype] = None,
):
"""
PEFT usually casts the layer norms to float32 for training stability,
so the input hidden states may get silently cast to float32 as well. Hence, we need to
cast them back to float16 / bfloat16 to make sure everything works as expected.
This might slow down training and inference, so it is recommended not to cast the LayerNorms!
Args:
query (`torch.Tensor`):
Input query states to be passed to Flash Attention API
key (`torch.Tensor`):
Input key states to be passed to Flash Attention API
value (`torch.Tensor`):
Input value states to be passed to Flash Attention API
target_dtype (`torch.dtype`, *optional*):
The dtype to convert the attention tensors to. The conversion is skipped when no
target dtype is provided.
"""
if target_dtype is None:
return query, key, value

input_dtype = value.dtype
if input_dtype == torch.float32:
logger.warning_once(
f"The input hidden states seems to be silently casted in float32, this might be related to"
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
f" {target_dtype}."
)

query = query.to(target_dtype)
key = key.to(target_dtype)
value = value.to(target_dtype)

return query, key, value


flash_241 = is_flash_attn_greater_or_equal("2.4.1")
deterministic_g = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
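
A minimal usage sketch of the new helper (assuming it can be imported directly from `transformers.modeling_flash_attention_utils`; the helper itself has no flash-attn dependency): fp32 tensors are cast down to the requested `target_dtype`, while `target_dtype=None` makes the check a no-op.

import torch
from transformers.modeling_flash_attention_utils import fa_peft_integration_check

# Hidden states that PEFT may have silently upcast to float32.
q = torch.randn(1, 8, 16, 64, dtype=torch.float32)
k = torch.randn(1, 8, 16, 64, dtype=torch.float32)
v = torch.randn(1, 8, 16, 64, dtype=torch.float32)

# With a target dtype, all three tensors are cast back (a one-time warning is logged).
q, k, v = fa_peft_integration_check(q, k, v, target_dtype=torch.float16)
assert q.dtype == k.dtype == v.dtype == torch.float16

# Without a target dtype, the check is a no-op.
q32 = torch.randn(1, 8, 16, 64, dtype=torch.float32)
q32, _, _ = fa_peft_integration_check(q32, k, v, target_dtype=None)
assert q32.dtype == torch.float32
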

@@ -202,6 +246,7 @@ def _flash_attention_forward(
cu_seq_lens_k: Optional[torch.LongTensor] = None,
max_length_q: Optional[int] = None,
max_length_k: Optional[int] = None,
target_dtype: Optional[torch.dtype] = None,
):
"""
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
Expand Down Expand Up @@ -248,6 +293,11 @@ def _flash_attention_forward(
if softcap is not None:
flash_kwargs["softcap"] = softcap

# PEFT may silently cast tensors to fp32; this casts them back to the correct dtype or is a no-op
query_states, key_states, value_states = fa_peft_integration_check(
query_states, key_states, value_states, target_dtype
)

# Contains at least one padding token in the sequence
if attention_mask is not None:
batch_size = query_states.shape[0]
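For context, here is a hedged caller-side sketch of the new `target_dtype` argument. Everything except `target_dtype` itself (the leading argument names of `_flash_attention_forward`, the `query_key_value` projection, `q_len`, and the dropout attribute) is an assumption modeled on the library's existing flash-attention integrations, not part of this diff: a module whose inputs may have been silently upcast by PEFT recovers the compute dtype from its own weights and forwards it, letting `fa_peft_integration_check` downcast right before the flash-attention call.

# Hypothetical attention-module snippet; argument and attribute names other than
# `target_dtype` are assumptions, not taken from this commit.
target_dtype = None
weight_dtype = self.query_key_value.weight.dtype  # assumed projection attribute
if value_states.dtype == torch.float32 and weight_dtype != torch.float32:
    # Inputs look silently upcast (e.g. by PEFT); recover the intended compute dtype.
    target_dtype = weight_dtype

attn_output = _flash_attention_forward(
    query_states,
    key_states,
    value_states,
    attention_mask,
    query_length=q_len,
    is_causal=True,
    dropout=self.attention_dropout if self.training else 0.0,
    target_dtype=target_dtype,
)
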
41 changes: 40 additions & 1 deletion src/transformers/modeling_utils.py
@@ -89,6 +89,7 @@
is_peft_available,
is_remote_url,
is_safetensors_available,
is_torch_flex_attn_available,
is_torch_greater_or_equal,
is_torch_sdpa_available,
is_torch_xla_available,
@@ -1342,6 +1343,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# SDPA support
_supports_sdpa = False

# Flex Attention support
_supports_flex_attn = False

# Has support for a `Cache` instance as `past_key_values`? Does it support a `StaticCache`?
_supports_cache_class = False
_supports_static_cache = False
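
A hypothetical sketch of how a model class opts in to the new flag (the class name is illustrative; GPTNeoX's actual model file is among the diffs not shown in this excerpt):

from transformers import PreTrainedModel

class MyGptNeoXPreTrainedModel(PreTrainedModel):
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_flex_attn = True  # new flag introduced by this commit
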
@@ -1548,6 +1552,10 @@ def _autoset_attn_implementation(
message += ', `"attn_implementation=flash_attention_2"` (implementation using flash attention 2)'
if cls._supports_sdpa:
message += ', `"attn_implementation=sdpa"` (implementation using torch.nn.functional.scaled_dot_product_attention)'
if cls._supports_flex_attn:
message += (
', `"attn_implementation=flex_attention"` (implementation using torch\'s flex_attention)'
)
raise ValueError(message + ".")

# If a config is passed with a preset attn_implementation, we skip the automatic dispatch and use the user-provided config, with hard checks that the requested attention implementation is available.
@@ -1582,6 +1590,8 @@ def _autoset_attn_implementation(
hard_check_only=False,
check_device_map=check_device_map,
)
elif requested_attn_implementation == "flex_attention":
config = cls._check_and_enable_flex_attn(config, hard_check_only=True)
elif requested_attn_implementation in [None, "sdpa"] and not is_torch_xla_available():
# use_flash_attention_2 takes priority over SDPA, hence SDPA treated in this elif.
config = cls._check_and_enable_sdpa(
@@ -1778,7 +1788,7 @@ def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False) -> Pretra
"""
Checks the availability of SDPA for a given model.
If all checks pass and `hard_check_only` is False, the method will set the config attribute `_attn_implementation` to "flash_attention_2" so that the model can initialize the correct attention module.
If all checks pass and `hard_check_only` is False, the method will set the config attribute `_attn_implementation` to "sdpa" so that the model can initialize the correct attention module.
"""
if hard_check_only:
if not cls._supports_sdpa:
@@ -1803,6 +1813,35 @@ def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False) -> Pretra
config._attn_implementation = "sdpa"
return config

@classmethod
def _check_and_enable_flex_attn(cls, config, hard_check_only: bool = False) -> PretrainedConfig:
"""
Checks the availability of Flex Attention for a given model.
If all checks pass and `hard_check_only` is False, the method will set the config attribute `_attn_implementation` to "flex_attention" so that the model can initialize the correct attention module.
"""
if hard_check_only:
if not cls._supports_flex_attn:
raise ValueError(
f"{cls.__name__} does not support an attention implementation through torch's flex_attention."
" Please request the support for this architecture: https://github.com/huggingface/transformers/issues/34809."
" If you believe this error is a bug, please open an issue in Transformers GitHub repository"
' and load your model with the argument `attn_implementation="eager"` meanwhile.'
' Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="eager")`'
)
if not is_torch_flex_attn_available():
raise ImportError(
"PyTorch Flex Attention requirements in Transformers are not met. Please install torch>=2.5.0."
)

if not is_torch_flex_attn_available() or not cls._supports_flex_attn:
return config

if not hard_check_only:
config._attn_implementation = "flex_attention"

return config

def enable_input_require_grads(self):
"""
Enables the gradients for the input embeddings. This is useful for fine-tuning adapter weights while keeping
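With the new `_supports_flex_attn` flag and `_check_and_enable_flex_attn` check in place, a model class that opts in (GPTNeoX in this PR) can be loaded with the new implementation. A minimal usage sketch, assuming torch>=2.5.0 is installed and using a small GPTNeoX-architecture checkpoint for illustration:

import torch
from transformers import AutoModelForCausalLM

# Requires torch>=2.5.0 (see the ImportError above); flex attention is primarily
# intended for GPU execution.
model = AutoModelForCausalLM.from_pretrained(
    "EleutherAI/pythia-160m",  # GPTNeoX architecture
    torch_dtype=torch.bfloat16,
    attn_implementation="flex_attention",
)
print(model.config._attn_implementation)  # "flex_attention"
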
(Diffs for the remaining 4 changed files, including src/transformers/models/gpt_neox/modeling_gpt_neox.py, are not shown.)