From 46df859975404e475cf5eeae76634f69951abe44 Mon Sep 17 00:00:00 2001
From: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com>
Date: Wed, 4 Dec 2024 14:48:28 +0100
Subject: [PATCH] [`GPTNeoX`] Flex Attention + Refactor (#34896)

* gpt neox flex attention + refactor

* some formatting

* small fix on dropout

* add assertion on flex attn test

* flaky ci :(

* add head mask support

* style

* handle dtype, replace torch where

* fixup flex with output attns

* code review and several other fixes

* Update src/transformers/modeling_utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* style

* remove unnecessary comment

* remove incorrect comment

* make flex attn check more agnostic tor versions and centralized

* change peft input dtype check to value since q and k could be affected by other stuff like RoPE

* i forgor

* flaky

* code review and small fixes

* Update src/transformers/models/gpt_neox/modeling_gpt_neox.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
---
 .../modeling_flash_attention_utils.py         |  52 +-
 src/transformers/modeling_utils.py            |  41 +-
 .../models/gpt_neox/modeling_gpt_neox.py      | 489 +++++++++---------
 src/transformers/utils/__init__.py            |   1 +
 src/transformers/utils/import_utils.py        |  12 +
 .../models/gpt_neox/test_modeling_gpt_neox.py |  25 +
 6 files changed, 371 insertions(+), 249 deletions(-)

diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py
index 1b9274e21f5205..ec03ba1eb5fd83 100644
--- a/src/transformers/modeling_flash_attention_utils.py
+++ b/src/transformers/modeling_flash_attention_utils.py
@@ -20,7 +20,10 @@
 import torch
 import torch.nn.functional as F
 
-from .utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal
+from .utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal, logging
+
+
+logger = logging.get_logger(__name__)
 
 
 if is_flash_attn_2_available():
@@ -180,6 +183,47 @@ def prepare_fa2_from_position_ids(query, key, value, position_ids):
     return (query, key, value, indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length))
 
 
+def fa_peft_integration_check(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    target_dtype: Optional[torch.dtype] = None,
+):
+    """
+    PEFT usually casts the layer norms in float32 for training stability reasons
+    therefore the input hidden states gets silently casted in float32. Hence, we need
+    cast them back in float16 / bfloat16 just to be sure everything works as expected.
+    This might slowdown training & inference so it is recommended to not cast the LayerNorms!
+
+    Args:
+        query (`torch.Tensor`):
+            Input query states to be passed to Flash Attention API
+        key (`torch.Tensor`):
+            Input key states to be passed to Flash Attention API
+        value (`torch.Tensor`):
+            Input value states to be passed to Flash Attention API
+        target_dtype (`torch.dtype`, *optional*):
+            The dtype to convert the attention tensors to. Conversion can be ignored by
+            not providing the target dtype.
+    """
+    if target_dtype is None:
+        return query, key, value
+
+    input_dtype = value.dtype
+    if input_dtype == torch.float32:
+        logger.warning_once(
+            f"The input hidden states seems to be silently casted in float32, this might be related to"
+            f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+            f" {target_dtype}."
+ ) + + query = query.to(target_dtype) + key = key.to(target_dtype) + value = value.to(target_dtype) + + return query, key, value + + flash_241 = is_flash_attn_greater_or_equal("2.4.1") deterministic_g = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1" @@ -202,6 +246,7 @@ def _flash_attention_forward( cu_seq_lens_k: Optional[torch.LongTensor] = None, max_length_q: Optional[int] = None, max_length_k: Optional[int] = None, + target_dtype: Optional[torch.dtype] = None, ): """ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token @@ -248,6 +293,11 @@ def _flash_attention_forward( if softcap is not None: flash_kwargs["softcap"] = softcap + # PEFT possibly silently casts tensors to fp32, this potentially reconverts to correct dtype or is a no op + query_states, key_states, value_states = fa_peft_integration_check( + query_states, key_states, value_states, target_dtype + ) + # Contains at least one padding token in the sequence if attention_mask is not None: batch_size = query_states.shape[0] diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 50622c9f55145a..dae29111c8dcc0 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -89,6 +89,7 @@ is_peft_available, is_remote_url, is_safetensors_available, + is_torch_flex_attn_available, is_torch_greater_or_equal, is_torch_sdpa_available, is_torch_xla_available, @@ -1342,6 +1343,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix # SDPA support _supports_sdpa = False + # Flex Attention support + _supports_flex_attn = False + # Has support for a `Cache` instance as `past_key_values`? Does it support a `StaticCache`? _supports_cache_class = False _supports_static_cache = False @@ -1548,6 +1552,10 @@ def _autoset_attn_implementation( message += ', `"attn_implementation=flash_attention_2"` (implementation using flash attention 2)' if cls._supports_sdpa: message += ', `"attn_implementation=sdpa"` (implementation using torch.nn.functional.scaled_dot_product_attention)' + if cls._supports_flex_attn: + message += ( + ', `"attn_implementation=flex_attention"` (implementation using torch\'s flex_attention)' + ) raise ValueError(message + ".") # If a config is passed with a preset attn_implementation, we skip the automatic dispatch and use the user-provided config, with hard checks that the requested attention implementation is available. @@ -1582,6 +1590,8 @@ def _autoset_attn_implementation( hard_check_only=False, check_device_map=check_device_map, ) + elif requested_attn_implementation == "flex_attention": + config = cls._check_and_enable_flex_attn(config, hard_check_only=True) elif requested_attn_implementation in [None, "sdpa"] and not is_torch_xla_available(): # use_flash_attention_2 takes priority over SDPA, hence SDPA treated in this elif. config = cls._check_and_enable_sdpa( @@ -1778,7 +1788,7 @@ def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False) -> Pretra """ Checks the availability of SDPA for a given model. - If all checks pass and `hard_check_only` is False, the method will set the config attribute `_attn_implementation` to "flash_attention_2" so that the model can initialize the correct attention module. + If all checks pass and `hard_check_only` is False, the method will set the config attribute `_attn_implementation` to "sdpa" so that the model can initialize the correct attention module. 
""" if hard_check_only: if not cls._supports_sdpa: @@ -1803,6 +1813,35 @@ def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False) -> Pretra config._attn_implementation = "sdpa" return config + @classmethod + def _check_and_enable_flex_attn(cls, config, hard_check_only: bool = False) -> PretrainedConfig: + """ + Checks the availability of Flex Attention for a given model. + + If all checks pass and `hard_check_only` is False, the method will set the config attribute `_attn_implementation` to "flex_attention" so that the model can initialize the correct attention module. + """ + if hard_check_only: + if not cls._supports_flex_attn: + raise ValueError( + f"{cls.__name__} does not support an attention implementation through torch's flex_attention." + " Please request the support for this architecture: https://github.com/huggingface/transformers/issues/34809." + " If you believe this error is a bug, please open an issue in Transformers GitHub repository" + ' and load your model with the argument `attn_implementation="eager"` meanwhile.' + ' Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="eager")`' + ) + if not is_torch_flex_attn_available(): + raise ImportError( + "PyTorch Flex Attention requirements in Transformers are not met. Please install torch>=2.5.0." + ) + + if not is_torch_flex_attn_available() or not cls._supports_flex_attn: + return config + + if not hard_check_only: + config._attn_implementation = "flex_attention" + + return config + def enable_input_require_grads(self): """ Enables the gradients for the input embeddings. This is useful for fine-tuning adapter weights while keeping diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index 359996983eed74..3fdb814ebab51a 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -18,7 +18,6 @@ import torch import torch.utils.checkpoint -from packaging import version from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss @@ -42,9 +41,9 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import PreTrainedModel from ...utils import ( - get_torch_version, is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, + is_torch_flex_attn_available, logging, ) from .configuration_gpt_neox import GPTNeoXConfig @@ -53,6 +52,9 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import flex_attention + logger = logging.get_logger(__name__) _CHECKPOINT_FOR_DOC = "trl-internal-testing/tiny-random-GPTNeoXForCausalLM" @@ -76,6 +78,7 @@ class GPTNeoXPreTrainedModel(PreTrainedModel): _supports_quantized_cache = True _supports_static_cache = True _supports_sdpa = True + _supports_flex_attn = True def _init_weights(self, module): """Initialize the weights""" @@ -92,6 +95,169 @@ def _init_weights(self, module): module.weight.data.fill_(1.0) +def eager_attention_forward( + query, key, value, attention_mask, head_mask, norm_factor, attention_dropout, training, **_kwargs +): + # q, k, v: [bs, num_attention_heads, seq_len, attn_head_size] + batch_size, num_attention_heads, query_length, attn_head_size = query.size() + key_length = key.size(-2) + + query = query.view(batch_size * num_attention_heads, query_length, attn_head_size) + key = key.view(batch_size * 
num_attention_heads, key_length, attn_head_size) + attn_scores = torch.zeros( + batch_size * num_attention_heads, + query_length, + key_length, + dtype=query.dtype, + device=key.device, + ) + attn_scores = torch.baddbmm( + attn_scores, + query, + key.transpose(1, 2), + beta=1.0, + alpha=norm_factor, + ) + attn_scores = attn_scores.view(batch_size, num_attention_heads, query_length, key_length) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key.shape[-2]] + attn_scores = attn_scores + causal_mask + + attn_weights = nn.functional.softmax(attn_scores, dim=-1) + attn_weights = attn_weights.to(value.dtype) + + # Mask heads if we want to + if head_mask is not None: + attn_weights = attn_weights * head_mask + + attn_weights = nn.functional.dropout(attn_weights, p=attention_dropout, training=training) + attn_output = torch.matmul(attn_weights, value) + + # Reshape outputs + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +def flash_attention_forward( + query, + key, + value, + attention_mask, + norm_factor, + attention_dropout, + training, + target_dtype=None, + **_kwargs, +): + query_length = query.shape[-2] + + # GPT-neo-X casts query and key in fp32 to apply rotary embedding in full precision + query = query.to(value.dtype) + key = key.to(value.dtype) + + # Permute to get the expected shape for Flash Attention + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + + attention_dropout = attention_dropout if training else 0.0 + flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + # Compute attention + attn_output = _flash_attention_forward( + query, + key, + value, + attention_mask, + query_length, + dropout=attention_dropout, + softmax_scale=norm_factor, + is_causal=True, + use_top_left_mask=flash_attn_uses_top_left_mask, + target_dtype=target_dtype, + ) + + return attn_output, None + + +def sdpa_attention_forward(query, key, value, attention_mask, attention_dropout, training, **_kwargs): + q_len = query.shape[-2] + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key.shape[-2]] + + # GPT-neo-X casts query and key in fp32 to apply rotary embedding in full precision + query = query.to(value.dtype) + key = key.to(value.dtype) + + # Avoid torch==2.1.2 specific bug for the memory-efficient backend in SDPA + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. 
+ is_causal = True if causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query=query, + key=key, + value=value, + attn_mask=causal_mask, + dropout_p=attention_dropout if training else 0.0, + is_causal=is_causal, + ) + + # Reshape outputs + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, None + + +def flex_attention_forward(query, key, value, attention_mask, head_mask, norm_factor, **_kwargs): + causal_mask = attention_mask + if causal_mask is not None: + causal_mask = causal_mask[:, :, :, : key.shape[-2]] + + def causal_mod(score, b, h, q_idx, kv_idx): + if causal_mask is not None: + score += causal_mask[b][0][q_idx][kv_idx] + if head_mask is not None: + score += head_mask[b][h][0][0] + return score + + attn_output, attn_weights = flex_attention( + query, + key, + value, + score_mod=causal_mod, + enable_gqa=True, + scale=norm_factor, + # Last time checked on PyTorch == 2.5.1: Flex Attention always computes the lse regardless. + # For simplification, we thus always return it as no additional computations are introduced. + return_lse=True, + ) + + # lse is returned in float32 + attn_weights = attn_weights.to(value.dtype) + + # Reshape outputs + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +GPTNEOX_ATTENTION_FUNCTION = { + "eager": eager_attention_forward, + "flash_attention_2": flash_attention_forward, + "sdpa": sdpa_attention_forward, + "flex_attention": flex_attention_forward, +} + + class GPTNeoXAttention(nn.Module): def __init__(self, config, layer_idx=None): super().__init__() @@ -147,20 +313,57 @@ def forward( cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 ): + bsz, seq_len, _ = hidden_states.shape + # Apply attention-specific projections and rope query, key, value, present = self._attn_projections_and_rope( hidden_states=hidden_states, position_ids=position_ids, layer_past=layer_past, use_cache=use_cache, + cache_position=cache_position, position_embeddings=position_embeddings, ) + # Checking for fallbacks in case an unsupported feature is requested + attention_type = self.config._attn_implementation + if (output_attentions or head_mask is not None) and self.config._attn_implementation in [ + "sdpa", + "flash_attention_2", + ]: + logger.warning_once( + f"Setting `attention_type` to `eager` because `{attention_type}` does not support" + f" `output_attentions=True` or `head_mask`." + ) + attention_type = "eager" + + elif ( + self.training + and self.config.attention_dropout > 0 + and self.config._attn_implementation == "flex_attention" + ): + logger.warning_once( + f"Setting `attention_type` to `eager` because `dropout` is not supported in `{attention_type}`." 
+ ) + attention_type = "eager" + # Compute attention - attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) + attn_output, attn_weights = GPTNEOX_ATTENTION_FUNCTION[attention_type]( + query, + key, + value, + attention_mask=attention_mask, + head_mask=head_mask, + norm_factor=self.norm_factor, + attention_dropout=self.config.attention_dropout, + training=self.training, + # Flash Attention 2 specific PEFT check + target_dtype=self._fa_peft_dtype_check(value), + ) - # Reshape outputs - attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_size) + # Reshape outputs and final projection + attn_output = attn_output.contiguous() + attn_output = attn_output.view(bsz, seq_len, -1) attn_output = self.dense(attn_output) outputs = (attn_output, present) @@ -250,262 +453,47 @@ def _attn_projections_and_rope( return query, key, value, layer_past - def _attn(self, query, key, value, attention_mask=None, head_mask=None): - # q, k, v: [bs, num_attention_heads, seq_len, attn_head_size] - # compute causal mask from causal mask buffer - batch_size, num_attention_heads, query_length, attn_head_size = query.size() - key_length = key.size(-2) - - # dynamically increase the causal mask with the key length, if needed. - if key_length > self.bias.shape[-1]: - self._init_bias(key_length, device=key.device) - causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length] - - query = query.view(batch_size * num_attention_heads, query_length, attn_head_size) - key = key.view(batch_size * num_attention_heads, key_length, attn_head_size) - attn_scores = torch.zeros( - batch_size * num_attention_heads, - query_length, - key_length, - dtype=query.dtype, - device=key.device, - ) - attn_scores = torch.baddbmm( - attn_scores, - query, - key.transpose(1, 2), - beta=1.0, - alpha=self.norm_factor, - ) - attn_scores = attn_scores.view(batch_size, num_attention_heads, query_length, key_length) - - mask_value = torch.finfo(attn_scores.dtype).min - # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. - # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` - mask_value = torch.tensor(mask_value, dtype=attn_scores.dtype).to(attn_scores.device) - attn_scores = torch.where(causal_mask, attn_scores, mask_value) - - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key.shape[-2]] - attn_scores = attn_scores + causal_mask - - attn_weights = nn.functional.softmax(attn_scores, dim=-1) - attn_weights = attn_weights.to(value.dtype) - - # Mask heads if we want to - if head_mask is not None: - attn_weights = attn_weights * head_mask - - attn_weights = self.attention_dropout(attn_weights) - - attn_output = torch.matmul(attn_weights, value) - return attn_output, attn_weights + def _fa_peft_dtype_check(self, value): + """ + PEFT can silently cast the dtype to float32 - this method returns the target dtype to which + FA should convert back to (if necessary). For now, we can not move this to the forward pass + itself due to the dependency on checking on some part of its own weights (last case). 
+ """ + target_dtype = None + if self.config._attn_implementation == "flash_attention_2": + input_dtype = value.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.query_key_value.weight.dtype + return target_dtype +# TODO Remove in deprecation cycle class GPTNeoXFlashAttention2(GPTNeoXAttention): - """ - GPTNeoX flash attention module. This module inherits from `GPTNeoXAttention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() - - def forward( - self, - hidden_states: torch.FloatTensor, - attention_mask: torch.FloatTensor, - position_ids: torch.LongTensor, - head_mask: Optional[torch.FloatTensor] = None, - layer_past: Optional[Cache] = None, - use_cache: Optional[bool] = False, - output_attentions: Optional[bool] = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 - ): - # Apply attention-specific projections and rope - query, key, value, present = self._attn_projections_and_rope( - hidden_states=hidden_states, - position_ids=position_ids, - layer_past=layer_past, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - ) - - query_length = query.shape[-2] - - # GPT-neo-X casts query and key in fp32 to apply rotary embedding in full precision - target_dtype = value.dtype - if query.dtype != target_dtype: - query = query.to(target_dtype) - if key.dtype != target_dtype: - key = key.to(target_dtype) - - # Permute to get the expected shape for Flash Attention - query = query.permute(0, 2, 1, 3) - key = key.permute(0, 2, 1, 3) - value = value.permute(0, 2, 1, 3) - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in float16 / bfloat16 just to be sure everything works as expected. 
- # This might slowdown training & inference so it is recommended to not cast the LayerNorms - input_dtype = query.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.query_key_value.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query = query.to(target_dtype) - key = key.to(target_dtype) - value = value.to(target_dtype) - - attention_dropout = self.config.attention_dropout if self.training else 0.0 - - # Compute attention - attn_weights = _flash_attention_forward( - query, - key, - value, - attention_mask, - query_length, - dropout=attention_dropout, - softmax_scale=self.norm_factor, - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - ) - - # Reshape outputs - attn_output = attn_weights.reshape( - attn_weights.shape[0], attn_weights.shape[1], self.num_attention_heads * self.head_size + logger.warning_once( + "The `GPTNeoXFlashAttention2` class is deprecated in favor of simply modifying the `config._attn_implementation`" + "attribute of the `GPTNeoXAttention` class! It will be removed in v4.48" ) - attn_output = self.dense(attn_output) - - outputs = (attn_output, layer_past) - if output_attentions: - outputs += (attn_weights,) - - return outputs +# TODO Remove in deprecation cycle class GPTNeoXSdpaAttention(GPTNeoXAttention): - """ - GPTNeoX attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from - `GPTNeoXAttention` as the weights of the module stays untouched. The only changes are on the forward pass - to adapt to the SDPA API. - """ - def __init__(self, config, layer_idx=None): super().__init__(config, layer_idx=layer_idx) - # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom - # attn_mask, so we need to call `.contiguous()`. This was fixed in torch==2.2.0. - # Reference: https://github.com/pytorch/pytorch/issues/112577 - self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0") - - def forward( - self, - hidden_states: torch.FloatTensor, - attention_mask: torch.FloatTensor, - position_ids: torch.LongTensor, - head_mask: Optional[torch.FloatTensor] = None, - layer_past: Optional[Tuple[torch.Tensor]] = None, - use_cache: Optional[bool] = False, - output_attentions: Optional[bool] = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 - ): - if output_attentions or head_mask is not None: - logger.warning_once( - "`GPTNeoXSdpaAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not support " - "`output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but " - "specifying the manual implementation will be required from Transformers version v5.0.0 onwards. " - 'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
- ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - head_mask=head_mask, - layer_past=layer_past, - use_cache=use_cache, - output_attentions=output_attentions, - cache_position=cache_position, - ) - - bsz, q_len, _ = hidden_states.size() - - # Apply attention-specific projections and rope - query, key, value, present = self._attn_projections_and_rope( - hidden_states=hidden_states, - position_ids=position_ids, - layer_past=layer_past, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - ) - - causal_mask = attention_mask - if attention_mask is not None: - causal_mask = causal_mask[:, :, :, : key.shape[-2]] - - # GPT-neo-X casts query and key in fp32 to apply rotary embedding in full precision - target_dtype = value.dtype - if query.dtype != target_dtype: - query = query.to(target_dtype) - if key.dtype != target_dtype: - key = key.to(target_dtype) - - # Avoid torch==2.1.2 specific bug for the memory-efficient backend in SDPA - if self.require_contiguous_qkv and query.device.type == "cuda" and attention_mask is not None: - query = query.contiguous() - key = key.contiguous() - value = value.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - is_causal = True if causal_mask is None and q_len > 1 else False - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query=query, - key=key, - value=value, - attn_mask=causal_mask, - dropout_p=self.attention_dropout.p if self.training else 0.0, - is_causal=is_causal, + logger.warning_once( + "The `GPTNeoXSdpaAttention` class is deprecated in favor of simply modifying the `config._attn_implementation`" + "attribute of the `GPTNeoXAttention` class! 
It will be removed in v4.48" ) - # Reshape outputs - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(bsz, q_len, self.hidden_size) - - attn_output = self.dense(attn_output) - - return attn_output, present, None - - -def attention_mask_func(attention_scores, ltor_mask): - attention_scores.masked_fill_(~ltor_mask, torch.finfo(attention_scores.dtype).min) - return attention_scores - # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->GPTNeoX class GPTNeoXRotaryEmbedding(nn.Module): @@ -675,6 +663,7 @@ def forward(self, hidden_states): "eager": GPTNeoXAttention, "flash_attention_2": GPTNeoXFlashAttention2, "sdpa": GPTNeoXSdpaAttention, + "flex_attention": GPTNeoXAttention, } @@ -919,7 +908,13 @@ def forward( # attention_probs has shape bsz x n_heads x N x N # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] - head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + converted_head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + # Flex Attention converts it to a separate mask + if head_mask is not None: + converted_head_mask = ~converted_head_mask.bool() * torch.finfo(inputs_embeds.dtype).min + converted_head_mask = converted_head_mask.to(dtype=self.dtype, device=self.device) + head_mask = converted_head_mask + hidden_states = self.emb_dropout(inputs_embeds) # create position embeddings to be shared across the decoder layers diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index 492642d61babb5..f7e962bec346fb 100755 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -206,6 +206,7 @@ is_torch_compile_available, is_torch_cuda_available, is_torch_deterministic, + is_torch_flex_attn_available, is_torch_fp16_available_on_device, is_torch_fx_available, is_torch_fx_proxy, diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 1a33de63351558..2ce4bd7bc778da 100755 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -358,6 +358,17 @@ def is_torch_sdpa_available(): return version.parse(_torch_version) >= version.parse("2.1.1") +def is_torch_flex_attn_available(): + if not is_torch_available(): + return False + elif _torch_version == "N/A": + return False + + # TODO check if some bugs cause push backs on the exact version + # NOTE: We require torch>=2.5.0 as it is the first release + return version.parse(_torch_version) >= version.parse("2.5.0") + + def is_torchvision_available(): return _torchvision_available @@ -916,6 +927,7 @@ def is_flash_attn_2_available(): return False +@lru_cache() def is_flash_attn_greater_or_equal_2_10(): if not _is_package_available("flash_attn"): return False diff --git a/tests/models/gpt_neox/test_modeling_gpt_neox.py b/tests/models/gpt_neox/test_modeling_gpt_neox.py index 2c3319f02475cc..435133e93860ac 100644 --- a/tests/models/gpt_neox/test_modeling_gpt_neox.py +++ b/tests/models/gpt_neox/test_modeling_gpt_neox.py @@ -459,6 +459,31 @@ def test_lm_generate_gptneox(self): self.assertEqual(output_str, expected_output) + @slow + def test_lm_generate_flex_attn_gptneox(self): + tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-410m-deduped") + for checkpointing in [True, False]: + model = GPTNeoXForCausalLM.from_pretrained( + "EleutherAI/pythia-410m-deduped", 
attn_implementation="flex_attention" + ) + self.assertTrue(model.config._attn_implementation == "flex_attention") + + if checkpointing: + model.gradient_checkpointing_enable() + else: + model.gradient_checkpointing_disable() + model.to(torch_device) + + inputs = tokenizer("My favorite food is", return_tensors="pt").to(torch_device) + # The hub repo. is updated on 2023-04-04, resulting in poor outputs. + # See: https://github.com/huggingface/transformers/pull/24193 + expected_output = "My favorite food is a good old-fashioned, old-fashioned, old-fashioned.\n\nI'm not sure" + + output_ids = model.generate(**inputs, do_sample=False, max_new_tokens=20) + output_str = tokenizer.batch_decode(output_ids)[0] + + self.assertEqual(output_str, expected_output) + def pythia_integration_test(self): model_name_or_path = "EleutherAI/pythia-70m" model = GPTNeoXForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float16).to(torch_device)