From 75be5a0a5b1898ee86e5e0c1f7b58b77bb105101 Mon Sep 17 00:00:00 2001 From: eustlb <94853470+eustlb@users.noreply.github.com> Date: Wed, 18 Dec 2024 16:38:19 +0100 Subject: [PATCH 01/23] [Whisper] fix docstrings typo (#35319) typos docstring --- .../models/whisper/generation_whisper.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py index 6b71671e14c852..360c0c0b687bab 100644 --- a/src/transformers/models/whisper/generation_whisper.py +++ b/src/transformers/models/whisper/generation_whisper.py @@ -382,7 +382,7 @@ def generate( the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`] for details. - generation_config (`~generation.GenerationConfig`, *optional*): + generation_config ([`~generation.GenerationConfig`], *optional*): The generation configuration to be used as base parametrization for the generation call. `**kwargs` passed to generate matching the attributes of `generation_config` will override them. If `generation_config` is not provided, the default will be used, which had the following loading @@ -480,8 +480,8 @@ def generate( `return_segments` is set True. In this case the generation outputs of each segment is added to each segment. force_unique_generate_call (`bool`, *optional*): - Whether to force a unique call to the underlying GenerationMixin's generate method. This is useful for assisted decoding and testing purposes to ensure - that only one call to generate is made and therefore decoder input token ids and eos token ids are returned. + Whether to force a unique call to the underlying GenerationMixin's [~generation.GenerationMixin.generate] method. This is useful for assisted decoding and testing purposes to ensure + that only one call to [~generation.GenerationMixin.generate] is made and therefore decoder input token ids and eos token ids are returned. kwargs (`Dict[str, Any]`, *optional*): Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder @@ -495,18 +495,18 @@ def generate( - `torch.LongTensor` in all other cases, excluding the decoder input ids and end of sequence id. The possible [`~utils.ModelOutput`] types are: - - [`~utils.GenerateEncoderDecoderOutput`] - - [`~utils.GenerateBeamEncoderDecoderOutput`] + - [`~generation.GenerateEncoderDecoderOutput`] + - [`~generation.GenerateBeamEncoderDecoderOutput`] `segments` is a list of lists (one list per batch element) of `segment`. A `segment` is a dictionary with keys `start`, `end`, `tokens`, `idxs`, and `result`. - `start`: the start timestamp of the segment. - `end`: the end timestamp of the segment. - `tokens`: the tokens of the segment, excluding the decoder input ids and end of sequence id. - - `idxs`: the start (included) and end (excluded) indices of the `tokens` of the segment in the underlying call to GenerationMixin's `generate` (present in `result`). - - `result`: the result of the underlying call to GenerationMixin's `generate`. 
+ - `idxs`: the start (included) and end (excluded) indices of the `tokens` of the segment in the underlying call to GenerationMixin's [~generation.GenerationMixin.generate] (present in `result`). + - `result`: the result of the underlying call to GenerationMixin's [~generation.GenerationMixin.generate]. - When `return_timestamps=True`, `return_dict_in_generate=True` applies to each call of the underlying GenerationMixin's `generate`, with outputs stored in `result` of each `segment`. + When `return_timestamps=True`, `return_dict_in_generate=True` applies to each call of the underlying GenerationMixin's [~generation.GenerationMixin.generate], with outputs stored in `result` of each `segment`. Example: @@ -543,7 +543,7 @@ def generate( ``` - *Shortform transcription*: If passed mel input features are <= 30 seconds, there are two possibilities: - - `return_timestamps=False`: the whole audio will be transcribed with a single call to GenerationMixin's generate. + - `return_timestamps=False`: the whole audio will be transcribed with a single call to GenerationMixin's [~generation.GenerationMixin.generate]. - `return_timestamps=True`: the audio will be transcribed using the same logic as long-form transcription. ```python From 2c47618c1a282f925446506d53108dc6e82d9ef0 Mon Sep 17 00:00:00 2001 From: Arthur <48595927+ArthurZucker@users.noreply.github.com> Date: Wed, 18 Dec 2024 16:53:39 +0100 Subject: [PATCH 02/23] =?UTF-8?q?=F0=9F=9A=A8All=20attention=20refactor?= =?UTF-8?q?=F0=9F=9A=A8=20(#35235)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * refactor LlamaAttention * minimal changes * fix llama * update * modular gemmas * modular nits * modular updates * nits * simplify * gpt2 * more modualr and fixes * granite * modular modular modular * nits * update * qwen2 + starcoder2 * mostly gemma2 * Update image_processing_auto.py * fix * Update modular_starcoder2.py * fix * remove all copied from attentions * remove gcv * make fix-copies * oups * oups2.0 * fix some modulars + all copied from * should be good now * revert unwanted changes * Update modeling_decision_transformer.py * finish cleanup * Update modeling_olmo.py * consistency * re-add gradient checkpointing attribute * fix * style * make config necessary * bis * bis * Update modeling_my_new_model2.py * is_causal attr * fix * remove past kv return from decoder layer * fix * default rope config * correctly fix rope config * fix bias * fix gpt2 attention output * fix test * fix inits * fix default sdpa * fix default sdpa implementation * harmonize classes * fix mistral * fix sliding window models * mixtral * be more explicit * style * fix * several fixes * Update modeling_dbrx.py * fix test * olmo + phi * rotary * syle * phi * phi again * again * kwargs * Update test_modeling_common.py * skip fx tracing tests * Update modeling_utils.py * gemma 2 * again * Update modeling_recurrent_gemma.py * gemma2 * granite * style * starcoder * Update sdpa_attention.py * switch args * Update modeling_mllama.py * fix * cache type tests * gpt2 * Update test_modeling_common.py * fix * consistency * fix shape with encoder * should be the last one * tests non model * most comments * small oupsi * be more explicit in modulars * more explicit modulars * CIs! 
it works locally * add kwargs to _flash_attention_forward --------- Co-authored-by: Cyril Vallez --- .../modular-transformers/modeling_dummy.py | 445 ++------- .../modeling_multimodal1.py | 447 ++------- .../modeling_my_new_model2.py | 519 +++------- .../modeling_new_task_model.py | 34 +- .../modular-transformers/modeling_super.py | 411 ++------ src/transformers/configuration_utils.py | 2 +- .../integrations/flash_attention.py | 52 + .../integrations/flex_attention.py | 44 + .../integrations/sdpa_attention.py | 55 + .../modeling_flash_attention_utils.py | 3 +- src/transformers/modeling_utils.py | 16 +- src/transformers/models/aria/modeling_aria.py | 438 ++------ src/transformers/models/bark/modeling_bark.py | 1 - src/transformers/models/bart/modeling_bart.py | 1 - .../models/chameleon/modeling_chameleon.py | 2 +- src/transformers/models/clip/modeling_clip.py | 1 - .../models/cohere/modeling_cohere.py | 38 +- .../models/cohere2/modeling_cohere2.py | 5 +- .../data2vec/modeling_data2vec_audio.py | 1 - src/transformers/models/dbrx/modeling_dbrx.py | 2 - .../modeling_decision_transformer.py | 164 +-- .../models/distilbert/modeling_distilbert.py | 1 - .../models/falcon/modeling_falcon.py | 37 +- .../models/gemma/modeling_gemma.py | 549 +++------- .../models/gemma/modular_gemma.py | 598 +---------- .../models/gemma2/modeling_gemma2.py | 437 +++----- .../models/gemma2/modular_gemma2.py | 359 ++----- src/transformers/models/glm/modeling_glm.py | 537 +++------- src/transformers/models/glm/modular_glm.py | 102 +- src/transformers/models/gpt2/modeling_gpt2.py | 399 ++------ .../gpt_bigcode/modeling_gpt_bigcode.py | 1 - .../models/gpt_neo/modeling_gpt_neo.py | 1 - .../models/gpt_neox/modeling_gpt_neox.py | 36 +- .../modeling_gpt_neox_japanese.py | 36 +- src/transformers/models/gptj/modeling_gptj.py | 1 - .../models/granite/modeling_granite.py | 685 ++++--------- .../models/granite/modular_granite.py | 291 ++++++ .../models/granitemoe/modeling_granitemoe.py | 21 +- .../models/hubert/modeling_hubert.py | 1 - .../models/idefics/modeling_idefics.py | 1 - .../models/idefics2/modeling_idefics2.py | 5 +- .../models/idefics3/modeling_idefics3.py | 1 - .../models/jamba/modeling_jamba.py | 7 +- .../models/jetmoe/modeling_jetmoe.py | 66 +- .../models/llama/modeling_llama.py | 439 ++------ .../models/m2m_100/modeling_m2m_100.py | 1 - .../models/mbart/modeling_mbart.py | 1 - src/transformers/models/mimi/modeling_mimi.py | 72 +- .../models/mistral/modeling_mistral.py | 788 +++++---------- .../models/mistral/modular_mistral.py | 350 +++++++ .../models/mixtral/modeling_mixtral.py | 942 +++++++----------- .../models/mixtral/modular_mixtral.py | 574 +++++++++++ .../models/mllama/modeling_mllama.py | 3 +- .../models/moshi/modeling_moshi.py | 72 +- .../models/musicgen/modeling_musicgen.py | 1 - .../modeling_musicgen_melody.py | 1 - .../models/nemotron/modeling_nemotron.py | 9 +- src/transformers/models/olmo/modeling_olmo.py | 671 ++++--------- src/transformers/models/olmo/modular_olmo.py | 126 +++ .../models/olmo2/configuration_olmo2.py | 1 + .../models/olmo2/modeling_olmo2.py | 590 ++++------- .../models/olmo2/modular_olmo2.py | 266 +---- .../models/olmoe/modeling_olmoe.py | 40 +- src/transformers/models/opt/modeling_opt.py | 1 - .../models/persimmon/modeling_persimmon.py | 36 +- src/transformers/models/phi/modeling_phi.py | 826 +++++---------- src/transformers/models/phi/modular_phi.py | 295 ++++++ src/transformers/models/phi3/modeling_phi3.py | 8 +- .../models/phimoe/modeling_phimoe.py | 5 +- 
.../models/pixtral/modeling_pixtral.py | 6 +- .../models/qwen2/modeling_qwen2.py | 793 +++++---------- .../models/qwen2/modular_qwen2.py | 134 +++ .../qwen2_audio/modeling_qwen2_audio.py | 1 - .../models/qwen2_moe/modeling_qwen2_moe.py | 49 +- .../models/qwen2_vl/modeling_qwen2_vl.py | 6 +- .../modeling_recurrent_gemma.py | 3 +- src/transformers/models/sew/modeling_sew.py | 1 - .../models/siglip/modeling_siglip.py | 1 - .../models/stablelm/modeling_stablelm.py | 43 +- .../models/starcoder2/modeling_starcoder2.py | 645 ++++-------- .../models/starcoder2/modular_starcoder2.py | 457 ++------- .../models/unispeech/modeling_unispeech.py | 1 - .../unispeech_sat/modeling_unispeech_sat.py | 1 - .../models/wav2vec2/modeling_wav2vec2.py | 1 - .../models/whisper/modeling_whisper.py | 1 - .../models/zamba/modeling_zamba.py | 7 +- .../test_modeling_encoder_decoder.py | 9 - tests/models/falcon/test_modeling_falcon.py | 28 +- tests/models/gpt2/test_modeling_gpt2.py | 2 +- .../models/gpt_neox/test_modeling_gpt_neox.py | 28 +- .../models/idefics2/test_modeling_idefics2.py | 9 - tests/models/llama/test_modeling_llama.py | 6 +- tests/models/mistral/test_modeling_mistral.py | 2 +- tests/models/mixtral/test_modeling_mixtral.py | 2 +- .../persimmon/test_modeling_persimmon.py | 29 +- tests/models/phi/test_modeling_phi.py | 29 +- tests/models/qwen2/test_modeling_qwen2.py | 2 +- .../qwen2_audio/test_modeling_qwen2_audio.py | 9 - .../qwen2_moe/test_modeling_qwen2_moe.py | 2 +- .../test_modeling_speech_encoder_decoder.py | 9 - .../models/stablelm/test_modeling_stablelm.py | 29 +- .../test_modeling_vision_encoder_decoder.py | 9 - tests/test_modeling_common.py | 38 +- tests/test_modeling_flax_common.py | 6 +- tests/test_modeling_tf_common.py | 5 +- tests/utils/test_modeling_utils.py | 35 - utils/check_config_attributes.py | 4 + 107 files changed, 5635 insertions(+), 9778 deletions(-) create mode 100644 src/transformers/integrations/flash_attention.py create mode 100644 src/transformers/integrations/flex_attention.py create mode 100644 src/transformers/integrations/sdpa_attention.py create mode 100644 src/transformers/models/granite/modular_granite.py create mode 100644 src/transformers/models/mistral/modular_mistral.py create mode 100644 src/transformers/models/mixtral/modular_mixtral.py create mode 100644 src/transformers/models/olmo/modular_olmo.py create mode 100644 src/transformers/models/phi/modular_phi.py create mode 100644 src/transformers/models/qwen2/modular_qwen2.py diff --git a/examples/modular-transformers/modeling_dummy.py b/examples/modular-transformers/modeling_dummy.py index 6172c9acfd2114..3e0aa6e9b2ad02 100644 --- a/examples/modular-transformers/modeling_dummy.py +++ b/examples/modular-transformers/modeling_dummy.py @@ -4,8 +4,7 @@ # the file from the modular. If any change should be done, please apply the change to the # modular_dummy.py file directly. One of our CI enforces this. 
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -import math -from typing import List, Optional, Tuple, Union +from typing import Callable, Optional, Tuple, Union import torch from torch import nn @@ -13,17 +12,12 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache from ...modeling_attn_mask_utils import AttentionMaskConverter -from ...modeling_flash_attention_utils import FlashAttentionKwargs, _flash_attention_forward +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import ( - add_start_docstrings, - add_start_docstrings_to_model_forward, - is_flash_attn_greater_or_equal_2_10, - logging, -) +from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging from .configuration_dummy import DummyConfig @@ -53,40 +47,18 @@ def extra_repr(self): class DummyRotaryEmbedding(nn.Module): def __init__( self, - dim=None, - max_position_embeddings=2048, - base=10000, + config: DummyConfig, device=None, - scaling_factor=1.0, - rope_type="default", - config: Optional[DummyConfig] = None, ): super().__init__() - # TODO (joao): remove the `if` below, only used for BC self.rope_kwargs = {} - if config is None: - logger.warning_once( - "`DummyRotaryEmbedding` can now be fully parameterized by passing the model config through the " - "`config` argument. All other arguments will be removed in v4.46" - ) - self.rope_kwargs = { - "rope_type": rope_type, - "factor": scaling_factor, - "dim": dim, - "base": base, - "max_position_embeddings": max_position_embeddings, - } - self.rope_type = rope_type - self.max_seq_len_cached = max_position_embeddings - self.original_max_seq_len = max_position_embeddings + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: - # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] @@ -199,144 +171,73 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, 
dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + class DummyAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__(self, config: DummyConfig, layer_idx: Optional[int] = None): + def __init__(self, config: DummyConfig, layer_idx: int): super().__init__() self.config = config self.layer_idx = layer_idx - if layer_idx is None: - logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " - "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class." - ) - + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 self.attention_dropout = config.attention_dropout - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads) - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.max_position_embeddings = config.max_position_embeddings - self.rope_theta = config.rope_theta self.is_causal = True - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - # use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used - query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = 
repeat_kv(value_states, self.num_key_value_groups) - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - - attn_output = attn_output.reshape(bsz, q_len, -1) - - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class DummyFlashAttention2(DummyAttention): - """ - Dummy flash attention module. This module inherits from `DummyAttention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
- self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if isinstance(past_key_value, StaticCache): - raise ValueError( - "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " - "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" - ) - - output_attentions = False - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) - # Flash attention requires the input to have the shape - # batch_size x seq_length x head_dim x hidden_dim - # therefore we just need to keep the original shape - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) @@ -346,159 +247,30 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache - # to be able to avoid many of these transpose/reshape/view. - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - dropout_rate = self.attention_dropout if self.training else 0.0 - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. 
- # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. (DummyRMSNorm handles it correctly) - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) else: - target_dtype = self.q_proj.weight.dtype + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - attn_output = _flash_attention_forward( + attn_output, attn_weights = attention_interface( + self, query_states, key_states, value_states, attention_mask, - q_len, - position_ids=position_ids, - dropout=dropout_rate, - sliding_window=getattr(self, "sliding_window", None), - use_top_left_mask=self._flash_attn_uses_top_left_mask, - is_causal=self.is_causal, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, **kwargs, ) - attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class DummySdpaAttention(DummyAttention): - """ - Dummy attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from - `DummyAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to - SDPA API. - """ - - # Adapted from DummyAttention.forward - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "DummyModel is using DummySdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. 
This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - ) - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - # use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used - query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - causal_mask = attention_mask - if attention_mask is not None: - causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and causal_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. 
- is_causal = True if causal_mask is None and q_len > 1 else False - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=causal_mask, - dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=is_causal, - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(bsz, q_len, -1) - - attn_output = self.o_proj(attn_output) - - return attn_output, None, past_key_value - - -DUMMY_ATTENTION_CLASSES = { - "eager": DummyAttention, - "flash_attention_2": DummyFlashAttention2, - "sdpa": DummySdpaAttention, -} + return attn_output, attn_weights class DummyDecoderLayer(nn.Module): @@ -506,7 +278,7 @@ def __init__(self, config: DummyConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = DUMMY_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + self.self_attn = DummyAttention(config=config, layer_idx=layer_idx) self.mlp = DummyMLP(config) self.input_layernorm = DummyRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -522,36 +294,14 @@ def forward( use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - **kwargs, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): - attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, - query_sequence_length, key_sequence_length)` if default attention is used. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence - position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): - Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, - with `head_dim` being the embedding dimension of each attention head. 
- kwargs (`dict`, *optional*): - Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code - into the model - """ residual = hidden_states hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, @@ -571,13 +321,9 @@ def forward( hidden_states = residual + hidden_states outputs = (hidden_states,) - if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) - return outputs @@ -724,10 +470,7 @@ def __init__(self, config: DummyConfig): ) self.norm = DummyRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = DummyRotaryEmbedding(config=config) - self.gradient_checkpointing = False - if getattr(config, "pretraining_tp", 1) != 1: - logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.") # Initialize weights and apply final processing self.post_init() @@ -744,7 +487,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, @@ -772,31 +515,22 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - # kept for BC (non `Cache` `past_key_values` inputs) - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True - if past_key_values is None: - past_key_values = DynamicCache() - else: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " - "will be removed in v4.47. 
Please convert your cache or use an appropriate `Cache` class " - "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" - ) + if use_cache and past_key_values is None: + past_key_values = DynamicCache() if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) + if position_ids is None: position_ids = cache_position.unsqueeze(0) causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions ) + hidden_states = inputs_embeds # create position embeddings to be shared across the decoder layers @@ -805,7 +539,6 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_decoder_cache = None for decoder_layer in self.layers[: self.config.num_hidden_layers]: if output_hidden_states: @@ -838,9 +571,6 @@ def forward( hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -850,18 +580,13 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None - if return_legacy_cache: - next_cache = next_cache.to_legacy_cache() - - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( + output = BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, ) + return output if return_dict else output.to_tuple() def _update_causal_mask( self, diff --git a/examples/modular-transformers/modeling_multimodal1.py b/examples/modular-transformers/modeling_multimodal1.py index 562e7dcab2b9f2..c4f90a5cbadab3 100644 --- a/examples/modular-transformers/modeling_multimodal1.py +++ b/examples/modular-transformers/modeling_multimodal1.py @@ -4,8 +4,7 @@ # the file from the modular. If any change should be done, please apply the change to the # modular_multimodal1.py file directly. One of our CI enforces this. 
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -import math -from typing import List, Optional, Tuple, Union +from typing import Callable, Optional, Tuple, Union import torch from torch import nn @@ -13,17 +12,12 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache from ...modeling_attn_mask_utils import AttentionMaskConverter -from ...modeling_flash_attention_utils import FlashAttentionKwargs, _flash_attention_forward +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import ( - add_start_docstrings, - add_start_docstrings_to_model_forward, - is_flash_attn_greater_or_equal_2_10, - logging, -) +from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging from .configuration_multimodal1 import Multimodal1TextConfig @@ -53,40 +47,18 @@ def extra_repr(self): class Multimodal1TextRotaryEmbedding(nn.Module): def __init__( self, - dim=None, - max_position_embeddings=2048, - base=10000, + config: Multimodal1TextConfig, device=None, - scaling_factor=1.0, - rope_type="default", - config: Optional[Multimodal1TextConfig] = None, ): super().__init__() - # TODO (joao): remove the `if` below, only used for BC self.rope_kwargs = {} - if config is None: - logger.warning_once( - "`Multimodal1TextRotaryEmbedding` can now be fully parameterized by passing the model config through the " - "`config` argument. All other arguments will be removed in v4.46" - ) - self.rope_kwargs = { - "rope_type": rope_type, - "factor": scaling_factor, - "dim": dim, - "base": base, - "max_position_embeddings": max_position_embeddings, - } - self.rope_type = rope_type - self.max_seq_len_cached = max_position_embeddings - self.original_max_seq_len = max_position_embeddings + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: - # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] @@ -199,144 +171,73 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = 
nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + class Multimodal1TextAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__(self, config: Multimodal1TextConfig, layer_idx: Optional[int] = None): + def __init__(self, config: Multimodal1TextConfig, layer_idx: int): super().__init__() self.config = config self.layer_idx = layer_idx - if layer_idx is None: - logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " - "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class." - ) - + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 self.attention_dropout = config.attention_dropout - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads) - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.max_position_embeddings = config.max_position_embeddings - self.rope_theta = config.rope_theta self.is_causal = True - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - # use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used - query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = 
repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - - attn_output = attn_output.reshape(bsz, q_len, -1) - - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class Multimodal1TextFlashAttention2(Multimodal1TextAttention): - """ - Multimodal1Text flash attention module. This module inherits from `Multimodal1TextAttention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
- self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if isinstance(past_key_value, StaticCache): - raise ValueError( - "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " - "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" - ) - - output_attentions = False - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) - # Flash attention requires the input to have the shape - # batch_size x seq_length x head_dim x hidden_dim - # therefore we just need to keep the original shape - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) @@ -346,159 +247,30 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache - # to be able to avoid many of these transpose/reshape/view. - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - dropout_rate = self.attention_dropout if self.training else 0.0 - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. 
- # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. (Multimodal1TextRMSNorm handles it correctly) - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - attn_output = _flash_attention_forward( + attn_output, attn_weights = attention_interface( + self, query_states, key_states, value_states, attention_mask, - q_len, - position_ids=position_ids, - dropout=dropout_rate, - sliding_window=getattr(self, "sliding_window", None), - use_top_left_mask=self._flash_attn_uses_top_left_mask, - is_causal=self.is_causal, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, **kwargs, ) - attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class Multimodal1TextSdpaAttention(Multimodal1TextAttention): - """ - Multimodal1Text attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from - `Multimodal1TextAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to - SDPA API. - """ - - # Adapted from Multimodal1TextAttention.forward - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "Multimodal1TextModel is using Multimodal1TextSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. 
This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - ) - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - # use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used - query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - causal_mask = attention_mask - if attention_mask is not None: - causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and causal_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. 
- is_causal = True if causal_mask is None and q_len > 1 else False - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=causal_mask, - dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=is_causal, - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(bsz, q_len, -1) - + attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) - - return attn_output, None, past_key_value - - -MULTIMODAL1_TEXT_ATTENTION_CLASSES = { - "eager": Multimodal1TextAttention, - "flash_attention_2": Multimodal1TextFlashAttention2, - "sdpa": Multimodal1TextSdpaAttention, -} + return attn_output, attn_weights class Multimodal1TextDecoderLayer(nn.Module): @@ -506,9 +278,7 @@ def __init__(self, config: Multimodal1TextConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = MULTIMODAL1_TEXT_ATTENTION_CLASSES[config._attn_implementation]( - config=config, layer_idx=layer_idx - ) + self.self_attn = Multimodal1TextAttention(config=config, layer_idx=layer_idx) self.mlp = Multimodal1TextMLP(config) self.input_layernorm = Multimodal1TextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -524,36 +294,14 @@ def forward( use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - **kwargs, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): - attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, - query_sequence_length, key_sequence_length)` if default attention is used. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence - position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): - Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, - with `head_dim` being the embedding dimension of each attention head. 
- kwargs (`dict`, *optional*): - Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code - into the model - """ residual = hidden_states hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, @@ -573,13 +321,9 @@ def forward( hidden_states = residual + hidden_states outputs = (hidden_states,) - if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) - return outputs @@ -726,10 +470,7 @@ def __init__(self, config: Multimodal1TextConfig): ) self.norm = Multimodal1TextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = Multimodal1TextRotaryEmbedding(config=config) - self.gradient_checkpointing = False - if getattr(config, "pretraining_tp", 1) != 1: - logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.") # Initialize weights and apply final processing self.post_init() @@ -746,7 +487,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, @@ -774,31 +515,22 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - # kept for BC (non `Cache` `past_key_values` inputs) - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True - if past_key_values is None: - past_key_values = DynamicCache() - else: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " - "will be removed in v4.47. 
Please convert your cache or use an appropriate `Cache` class " - "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" - ) + if use_cache and past_key_values is None: + past_key_values = DynamicCache() if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) + if position_ids is None: position_ids = cache_position.unsqueeze(0) causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions ) + hidden_states = inputs_embeds # create position embeddings to be shared across the decoder layers @@ -807,7 +539,6 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_decoder_cache = None for decoder_layer in self.layers[: self.config.num_hidden_layers]: if output_hidden_states: @@ -840,9 +571,6 @@ def forward( hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -852,18 +580,13 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None - if return_legacy_cache: - next_cache = next_cache.to_legacy_cache() - - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( + output = BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, ) + return output if return_dict else output.to_tuple() def _update_causal_mask( self, diff --git a/examples/modular-transformers/modeling_my_new_model2.py b/examples/modular-transformers/modeling_my_new_model2.py index 189e090094c76c..b8d5b5eb910095 100644 --- a/examples/modular-transformers/modeling_my_new_model2.py +++ b/examples/modular-transformers/modeling_my_new_model2.py @@ -4,8 +4,7 @@ # the file from the modular. If any change should be done, please apply the change to the # modular_my_new_model2.py file directly. One of our CI enforces this. 
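The attention refactor shown in the hunks above collapses the per-backend classes (eager / SDPA / flash-attention) into a single attention module that resolves a callable at forward time and falls back to the eager path when needed. As a rough, self-contained sketch of that dispatch idea only — the names `ATTENTION_FUNCTIONS`, `pick_attention`, and `eager_attention` below are illustrative placeholders, not the library's API:

```python
from typing import Callable, Dict, Optional, Tuple

import torch


def eager_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Plain matmul + softmax attention, mirroring the eager path in the diff.
    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask[:, :, :, : key.shape[-2]]
    attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_output = torch.matmul(attn_weights, value).transpose(1, 2).contiguous()
    return attn_output, attn_weights


# Hypothetical registry standing in for the library's attention-function registry;
# other backends (e.g. "sdpa", "flash_attention_2") would register callables here.
ATTENTION_FUNCTIONS: Dict[str, Callable] = {"eager": eager_attention}


def pick_attention(attn_implementation: str, output_attentions: bool) -> Callable:
    # SDPA cannot return attention weights, so fall back to eager in that case,
    # matching the warning emitted in the refactored forward passes.
    if attn_implementation == "sdpa" and output_attentions:
        return eager_attention
    return ATTENTION_FUNCTIONS.get(attn_implementation, eager_attention)
```

In this sketch a module would call `pick_attention(config_attn_implementation, output_attentions)(q, k, v, mask, scaling)` once per forward pass instead of instantiating separate eager/SDPA/flash attention classes, which is the shape of the change applied to every example model in this patch.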
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -import math -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch from torch import nn @@ -13,15 +12,12 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache from ...modeling_attn_mask_utils import AttentionMaskConverter -from ...modeling_flash_attention_utils import _flash_attention_forward +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast, SequenceClassifierOutputWithPast -from ...modeling_utils import PreTrainedModel -from ...utils import ( - add_start_docstrings, - add_start_docstrings_to_model_forward, - is_flash_attn_greater_or_equal_2_10, - logging, -) +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging from .configuration_my_new_model2 import MyNewModel2Config @@ -48,24 +44,72 @@ def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.eps}" +class MyNewModel2MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + class MyNewModel2RotaryEmbedding(nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + def __init__( + self, + config: MyNewModel2Config, + device=None, + ): super().__init__() + self.rope_kwargs = {} + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim)) - self.register_buffer("inv_freq", tensor=inv_freq, persistent=False) + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", 
inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len @torch.no_grad() - def forward(self, x, position_ids, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - self.inv_freq.to(x.device) + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() - # Force float32 since bfloat16 loses precision on long contexts - # See https://github.com/huggingface/transformers/pull/29285 + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): @@ -73,31 +117,12 @@ def forward(self, x, position_ids, seq_len=None): emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling -class MyNewModel2MLP(nn.Module): - def __init__(self, config): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) - if config.hidden_activation is None: - logger.warning_once( - "`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.\n" - "MyNewModel2's activation function will be set to `gelu_pytorch_tanh`. Please, use\n" - "`config.hidden_activation` if you want to override this behaviour.\n" - "See https://github.com/huggingface/transformers/pull/29402 for more details." 
- ) - config.hidden_activation = "gelu_pytorch_tanh" - hidden_activation = config.hidden_activation - self.act_fn = ACT2FN[hidden_activation] - - def forward(self, x): - return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) def rotate_half(x): @@ -146,241 +171,75 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + class MyNewModel2Attention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__(self, config: MyNewModel2Config, layer_idx: Optional[int] = None): + def __init__(self, config: MyNewModel2Config, layer_idx: int): super().__init__() self.config = config self.layer_idx = layer_idx - if layer_idx is None: - logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " - "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class." - ) - + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 self.attention_dropout = config.attention_dropout - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = config.head_dim - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.max_position_embeddings = config.max_position_embeddings - self.rope_theta = config.rope_theta self.is_causal = True - self.scaling = 1 / math.sqrt(config.head_dim) - - if self.hidden_size % self.num_heads != 0: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." 
- ) - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) - self.rotary_emb = MyNewModel2RotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - base=self.rope_theta, + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias ) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - cos, sin = self.rotary_emb(value_states, position_ids) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling - - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - - attn_output = attn_output.view(bsz, q_len, -1) - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class MyNewModel2SdpaAttention(MyNewModel2Attention): - """ - MyNewModel2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from - `MyNewModel2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to - SDPA API. 
- """ - - # Adapted from MyNewModel2Attention.forward - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "MyNewModel2Model is using MyNewModel2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - ) - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - cos, sin = self.rotary_emb(value_states, position_ids) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - causal_mask = attention_mask - if attention_mask is not None: - causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and causal_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. 
- is_causal = True if causal_mask is None and q_len > 1 else False - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=causal_mask, - dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=is_causal, + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(bsz, q_len, -1) - - attn_output = self.o_proj(attn_output) - - return attn_output, None, past_key_value - - -class MyNewModel2FlashAttention2(MyNewModel2Attention): - """ - MyNewModel2 flash attention module. This module inherits from `MyNewModel2Attention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
- self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if isinstance(past_key_value, StaticCache): - raise ValueError( - "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " - "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" - ) - - output_attentions = False - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) - # Flash attention requires the input to have the shape - # batch_size x seq_length x head_dim x hidden_dim - # therefore we just need to keep the original shape - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - cos, sin = self.rotary_emb(value_states, position_ids) + cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: @@ -388,75 +247,39 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache - # to be able to avoid many of these transpose/reshape/view. - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - dropout_rate = self.attention_dropout if self.training else 0.0 - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. 
(MyNewModel2RMSNorm handles it correctly) - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - attn_output = _flash_attention_forward( + attn_output, attn_weights = attention_interface( + self, query_states, key_states, value_states, attention_mask, - q_len, - position_ids=position_ids, - dropout=dropout_rate, - sliding_window=getattr(self, "sliding_window", None), - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, ) - attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - return attn_output, attn_weights, past_key_value - - -MY_NEW_MODEL2_ATTENTION_CLASSES = { - "eager": MyNewModel2Attention, - "flash_attention_2": MyNewModel2FlashAttention2, - "sdpa": MyNewModel2SdpaAttention, -} + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights class MyNewModel2DecoderLayer(nn.Module): def __init__(self, config: MyNewModel2Config, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = MY_NEW_MODEL2_ATTENTION_CLASSES[config._attn_implementation]( - config=config, layer_idx=layer_idx - ) + + self.self_attn = MyNewModel2Attention(config=config, layer_idx=layer_idx) + self.mlp = MyNewModel2MLP(config) self.input_layernorm = MyNewModel2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = MyNewModel2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -470,33 +293,15 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - **kwargs, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): - attention mask of size `(batch_size, sequence_length)` if flash attention is 
used or `(batch_size, 1, - query_sequence_length, key_sequence_length)` if default attention is used. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence - kwargs (`dict`, *optional*): - Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code - into the model - """ residual = hidden_states hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, @@ -504,6 +309,7 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + position_embeddings=position_embeddings, **kwargs, ) hidden_states = residual + hidden_states @@ -515,13 +321,9 @@ def forward( hidden_states = residual + hidden_states outputs = (hidden_states,) - if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) - return outputs @@ -667,10 +469,8 @@ def __init__(self, config: MyNewModel2Config): [MyNewModel2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) self.norm = MyNewModel2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - + self.rotary_emb = MyNewModel2RotaryEmbedding(config=config) self.gradient_checkpointing = False - if getattr(config, "pretraining_tp", 1) != 1: - logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.") # Initialize weights and apply final processing self.post_init() @@ -714,19 +514,8 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - # kept for BC (non `Cache` `past_key_values` inputs) - return_legacy_cache = False # noqa: F841 - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True # noqa: F841 - if past_key_values is None: - past_key_values = DynamicCache() - else: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " - "will be removed in v4.47. 
Please convert your cache or use an appropriate `Cache` class " - "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" - ) + if use_cache and past_key_values is None: + past_key_values = DynamicCache() if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 @@ -744,6 +533,9 @@ def forward( # embed positions hidden_states = inputs_embeds + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + # normalized # MyNewModel2 downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5 # See https://github.com/huggingface/transformers/pull/29402 @@ -753,7 +545,6 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_decoder_cache = None for decoder_layer in self.layers[: self.config.num_hidden_layers]: if output_hidden_states: @@ -769,6 +560,7 @@ def forward( output_attentions, use_cache, cache_position, + position_embeddings, ) else: layer_outputs = decoder_layer( @@ -779,13 +571,11 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + position_embeddings=position_embeddings, ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -795,18 +585,13 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None - if return_legacy_cache: - next_cache = next_cache.to_legacy_cache() - - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( + output = BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, ) + return output if return_dict else output.to_tuple() def _update_causal_mask( self, diff --git a/examples/modular-transformers/modeling_new_task_model.py b/examples/modular-transformers/modeling_new_task_model.py index d303d328e887d6..477d084b1d9309 100644 --- a/examples/modular-transformers/modeling_new_task_model.py +++ b/examples/modular-transformers/modeling_new_task_model.py @@ -10,7 +10,7 @@ import torch from torch import nn -from ...cache_utils import Cache, StaticCache +from ...cache_utils import Cache, HybridCache, StaticCache from ...generation import GenerationMixin from ...modeling_utils import PreTrainedModel from ...utils import ( @@ -253,7 +253,14 @@ def tie_weights(self): return self.language_model.tie_weights() def _update_causal_mask( - self, attention_mask, token_type_ids, inputs_embeds, past_key_values, cache_position, is_training: bool = False + self, + attention_mask, + token_type_ids, + past_key_values, + cache_position, + input_ids=None, + inputs_embeds=None, + is_training: bool = False, ): if self.config.text_config._attn_implementation == "flash_attention_2": if attention_mask is not None and 0.0 in attention_mask: @@ -261,11 +268,13 @@ def _update_causal_mask( return None using_static_cache = isinstance(past_key_values, StaticCache) - dtype = inputs_embeds.dtype - min_dtype = torch.finfo(dtype).min - sequence_length = inputs_embeds.shape[1] + min_dtype = torch.finfo(self.dtype).min + inputs_lead_dim = 
input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0] + sequence_length = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] if using_static_cache: target_length = past_key_values.get_max_cache_shape() + elif isinstance(past_key_values, HybridCache): + target_length = past_key_values.get_max_cache_shape() else: target_length = ( attention_mask.shape[-1] @@ -278,7 +287,7 @@ def _update_causal_mask( return attention_mask causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + (sequence_length, target_length), fill_value=min_dtype, dtype=self.dtype, device=cache_position.device ) # Causal diagonal mask only if training, otherwise attend to the whole prefix. Training-specific attn for prefix is handled below if sequence_length != 1: @@ -288,7 +297,7 @@ def _update_causal_mask( causal_mask[:, :sequence_length] = 0.0 causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(inputs_embeds.shape[0], 1, -1, -1) + causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit mask_length = attention_mask.shape[-1] @@ -317,7 +326,7 @@ def get_image_features(self, pixel_values: torch.FloatTensor): image_outputs = self.vision_tower(pixel_values) selected_image_feature = image_outputs.last_hidden_state image_features = self.multi_modal_projector(selected_image_feature) - image_features = image_features / (self.config.hidden_size**0.5) + image_features = image_features / (self.config.text_config.hidden_size**0.5) return image_features @add_start_docstrings_to_model_forward(NEW_TASK_MODEL_INPUTS_DOCSTRING) @@ -414,6 +423,7 @@ def prepare_inputs_for_generation( token_type_ids=None, use_cache=True, num_logits_to_keep=None, + labels=None, **kwargs, ): # Overwritten -- custom `position_ids` and `pixel_values` handling @@ -433,12 +443,16 @@ def prepare_inputs_for_generation( # position_ids in NewTaskModel are 1-indexed if model_inputs.get("position_ids") is not None: model_inputs["position_ids"] += 1 - # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore # Otherwise we need pixel values to be passed to model. NOTE: use_cache=False needs pixel_values always if cache_position[0] == 0: model_inputs["pixel_values"] = pixel_values - + is_training = token_type_ids is not None and labels is not None + if cache_position[0] == 0 and isinstance(past_key_values, HybridCache): + causal_mask = self._update_causal_mask( + attention_mask, token_type_ids, past_key_values, cache_position, input_ids, inputs_embeds, is_training + ) + model_inputs["attention_mask"] = causal_mask return model_inputs def resize_token_embeddings( diff --git a/examples/modular-transformers/modeling_super.py b/examples/modular-transformers/modeling_super.py index 79e5ab15a5eda6..42d8108ee72a68 100644 --- a/examples/modular-transformers/modeling_super.py +++ b/examples/modular-transformers/modeling_super.py @@ -4,8 +4,7 @@ # the file from the modular. If any change should be done, please apply the change to the # modular_super.py file directly. One of our CI enforces this. 
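A second change that repeats across these model diffs is the cache handling: the legacy tuple-of-tuples conversion (and its deprecation warning) is dropped, and a `DynamicCache` is simply created when caching is requested without a cache being passed. A minimal sketch of that behaviour, using the public `DynamicCache` class; the helper name `ensure_cache` is illustrative, not part of the library:

```python
from typing import Optional

from transformers import DynamicCache
from transformers.cache_utils import Cache


def ensure_cache(past_key_values: Optional[Cache], use_cache: bool) -> Optional[Cache]:
    # New behaviour: only Cache objects are handled; no legacy-cache conversion.
    if use_cache and past_key_values is None:
        past_key_values = DynamicCache()
    return past_key_values


cache = ensure_cache(None, use_cache=True)
print(cache.get_seq_length())  # 0 for a freshly created, empty cache
```

The refactored model forwards then return this same cache object in `BaseModelOutputWithPast.past_key_values` (or `None` when `use_cache` is false) rather than rebuilding a `next_decoder_cache` tuple layer by layer.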
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -import math -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch from torch import nn @@ -13,17 +12,12 @@ from ...activations import ACT2FN from ...cache_utils import Cache, StaticCache from ...modeling_attn_mask_utils import AttentionMaskConverter -from ...modeling_flash_attention_utils import FlashAttentionKwargs, _flash_attention_forward +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import ( - add_start_docstrings, - add_start_docstrings_to_model_forward, - is_flash_attn_greater_or_equal_2_10, - logging, -) +from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging from .configuration_super import SuperConfig @@ -53,40 +47,18 @@ def extra_repr(self): class SuperRotaryEmbedding(nn.Module): def __init__( self, - dim=None, - max_position_embeddings=2048, - base=10000, + config: SuperConfig, device=None, - scaling_factor=1.0, - rope_type="default", - config: Optional[SuperConfig] = None, ): super().__init__() - # TODO (joao): remove the `if` below, only used for BC self.rope_kwargs = {} - if config is None: - logger.warning_once( - "`SuperRotaryEmbedding` can now be fully parameterized by passing the model config through the " - "`config` argument. All other arguments will be removed in v4.46" - ) - self.rope_kwargs = { - "rope_type": rope_type, - "factor": scaling_factor, - "dim": dim, - "base": base, - "max_position_embeddings": max_position_embeddings, - } - self.rope_type = rope_type - self.max_seq_len_cached = max_position_embeddings - self.original_max_seq_len = max_position_embeddings + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: - # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] @@ -199,144 +171,73 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, 
dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + class SuperAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__(self, config: SuperConfig, layer_idx: Optional[int] = None): + def __init__(self, config: SuperConfig, layer_idx: int): super().__init__() self.config = config self.layer_idx = layer_idx - if layer_idx is None: - logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " - "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class." - ) - + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 self.attention_dropout = config.attention_dropout - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads) - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.max_position_embeddings = config.max_position_embeddings - self.rope_theta = config.rope_theta self.is_causal = True - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - # use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used - query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = 
repeat_kv(value_states, self.num_key_value_groups) - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - - attn_output = attn_output.reshape(bsz, q_len, -1) - - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class SuperFlashAttention2(SuperAttention): - """ - Super flash attention module. This module inherits from `SuperAttention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
- self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if isinstance(past_key_value, StaticCache): - raise ValueError( - "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " - "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" - ) - - output_attentions = False - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) - # Flash attention requires the input to have the shape - # batch_size x seq_length x head_dim x hidden_dim - # therefore we just need to keep the original shape - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) @@ -346,159 +247,30 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache - # to be able to avoid many of these transpose/reshape/view. - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - dropout_rate = self.attention_dropout if self.training else 0.0 - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. 
- # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. (SuperRMSNorm handles it correctly) - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - attn_output = _flash_attention_forward( + attn_output, attn_weights = attention_interface( + self, query_states, key_states, value_states, attention_mask, - q_len, - position_ids=position_ids, - dropout=dropout_rate, - sliding_window=getattr(self, "sliding_window", None), - use_top_left_mask=self._flash_attn_uses_top_left_mask, - is_causal=self.is_causal, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, **kwargs, ) - attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class SuperSdpaAttention(SuperAttention): - """ - Super attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from - `SuperAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to - SDPA API. - """ - - # Adapted from SuperAttention.forward - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "SuperModel is using SuperSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. 
This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - ) - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - # use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used - query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - causal_mask = attention_mask - if attention_mask is not None: - causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and causal_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. 
- is_causal = True if causal_mask is None and q_len > 1 else False - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=causal_mask, - dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=is_causal, - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(bsz, q_len, -1) - - attn_output = self.o_proj(attn_output) - - return attn_output, None, past_key_value - - -SUPER_ATTENTION_CLASSES = { - "eager": SuperAttention, - "flash_attention_2": SuperFlashAttention2, - "sdpa": SuperSdpaAttention, -} + return attn_output, attn_weights class SuperDecoderLayer(nn.Module): @@ -506,7 +278,7 @@ def __init__(self, config: SuperConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = SUPER_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + self.self_attn = SuperAttention(config=config, layer_idx=layer_idx) self.mlp = SuperMLP(config) self.input_layernorm = SuperRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -522,36 +294,14 @@ def forward( use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - **kwargs, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): - attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, - query_sequence_length, key_sequence_length)` if default attention is used. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence - position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): - Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, - with `head_dim` being the embedding dimension of each attention head. 
- kwargs (`dict`, *optional*): - Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code - into the model - """ residual = hidden_states hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, @@ -571,13 +321,9 @@ def forward( hidden_states = residual + hidden_states outputs = (hidden_states,) - if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) - return outputs @@ -724,10 +470,7 @@ def __init__(self, config: SuperConfig): ) self.norm = SuperRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = SuperRotaryEmbedding(config=config) - self.gradient_checkpointing = False - if getattr(config, "pretraining_tp", 1) != 1: - logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.") # Initialize weights and apply final processing self.post_init() diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index a04b7bd6aa1b6d..648877c8dce962 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -37,10 +37,10 @@ download_url, extract_commit_hash, is_remote_url, - is_timm_config_dict, is_torch_available, logging, ) +from .utils.generic import is_timm_config_dict logger = logging.get_logger(__name__) diff --git a/src/transformers/integrations/flash_attention.py b/src/transformers/integrations/flash_attention.py new file mode 100644 index 00000000000000..1be223f8b079ba --- /dev/null +++ b/src/transformers/integrations/flash_attention.py @@ -0,0 +1,52 @@ +from typing import Optional, Tuple + +import torch + +from ..modeling_flash_attention_utils import _flash_attention_forward +from ..utils import is_flash_attn_greater_or_equal_2_10 + + +_use_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + +def flash_attention_forward( + module: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + dropout: float = 0.0, + scaling: Optional[float] = None, + sliding_window: Optional[int] = None, + softcap: Optional[float] = None, + **kwargs, +) -> Tuple[torch.Tensor, None]: + # This is before the transpose + seq_len = query.shape[2] + + # FA2 uses non-transposed inputs + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + + if query.dtype == torch.float32: + query = query.to(torch.float16) + key = key.to(torch.float16) + value = value.to(torch.float16) + + attn_output = _flash_attention_forward( + query, + key, + value, + attention_mask, + seq_len, + module.is_causal, + dropout=dropout, + softmax_scale=scaling, + sliding_window=sliding_window, + softcap=softcap, + use_top_left_mask=_use_top_left_mask, + **kwargs, + ) + + return attn_output, None diff --git a/src/transformers/integrations/flex_attention.py b/src/transformers/integrations/flex_attention.py new file mode 100644 index 00000000000000..eacfb2b568b55b --- /dev/null +++ b/src/transformers/integrations/flex_attention.py @@ -0,0 +1,44 @@ +from typing import Optional, Tuple + +import torch + +from ..utils import is_torch_greater_or_equal + + +if is_torch_greater_or_equal("2.5"): + from torch.nn.attention.flex_attention import flex_attention + + +def flex_attention_forward( + module: torch.nn.Module, + query: 
torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: Optional[float] = None, + softcap: Optional[float] = None, + **kwargs, +) -> Tuple[torch.Tensor, torch.Tensor]: + causal_mask = attention_mask + if causal_mask is not None: + causal_mask = causal_mask[:, :, :, : key.shape[-2]] + + def causal_mod(score, b, h, q_idx, kv_idx): + if softcap is not None: + score = softcap * torch.tanh(score / softcap) + if causal_mask is not None: + score += causal_mask[b][0][q_idx][kv_idx] + return score + + attn_output, attention_weights = flex_attention( + query, + key, + value, + score_mod=causal_mod, + enable_gqa=True, + scale=scaling, + return_lse=True, + ) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attention_weights diff --git a/src/transformers/integrations/sdpa_attention.py b/src/transformers/integrations/sdpa_attention.py new file mode 100644 index 00000000000000..265260c9b79e4c --- /dev/null +++ b/src/transformers/integrations/sdpa_attention.py @@ -0,0 +1,55 @@ +from typing import Optional, Tuple + +import torch + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def sdpa_attention_forward( + module: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + dropout: float = 0.0, + scaling: Optional[float] = None, + is_causal: Optional[bool] = None, + **kwargs, +) -> Tuple[torch.Tensor, None]: + if hasattr(module, "num_key_value_groups"): + key = repeat_kv(key, module.num_key_value_groups) + value = repeat_kv(value, module.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key.shape[-2]] + + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + + if is_causal is None: + is_causal = causal_mask is None and query.shape[2] > 1 + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query, + key, + value, + attn_mask=causal_mask, + dropout_p=dropout, + scale=scaling, + is_causal=is_causal, + ) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, None diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py index ec03ba1eb5fd83..6adda0036cc096 100644 --- a/src/transformers/modeling_flash_attention_utils.py +++ b/src/transformers/modeling_flash_attention_utils.py @@ -247,6 +247,7 @@ def _flash_attention_forward( max_length_q: Optional[int] = None, max_length_k: Optional[int] = None, target_dtype: Optional[torch.dtype] = None, + **kwargs, ): """ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token @@ -276,7 +277,7 @@ def _flash_attention_forward( if not use_top_left_mask: causal = is_causal else: - # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. 
For details, please see the comment in transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__. + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. causal = is_causal and query_length != 1 # Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length). diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 2ea88fb9b05b90..9dcd6d758ecbe7 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -45,6 +45,9 @@ from .dynamic_module_utils import custom_object_save from .generation import CompileConfig, GenerationConfig, GenerationMixin from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled +from .integrations.flash_attention import flash_attention_forward +from .integrations.flex_attention import flex_attention_forward +from .integrations.sdpa_attention import sdpa_attention_forward from .loss.loss_utils import LOSS_MAPPING from .pytorch_utils import ( # noqa: F401 Conv1D, @@ -171,10 +174,8 @@ def is_local_dist_rank_0(): if is_peft_available(): from .utils import find_adapter_config_file - SpecificPreTrainedModelType = TypeVar("SpecificPreTrainedModelType", bound="PreTrainedModel") - TORCH_INIT_FUNCTIONS = { "uniform_": nn.init.uniform_, "normal_": nn.init.normal_, @@ -5634,3 +5635,14 @@ def get_disk_only_shard_files(device_map, sharded_metadata, start_prefix): files_content[filename].append(device_map[weight_name]) return [fname for fname, devices in files_content.items() if set(devices) == {"disk"}] + + +ALL_ATTENTION_FUNCTIONS: Dict[str, Dict[str, Callable]] = {} + +ALL_ATTENTION_FUNCTIONS.update( + { + "flash_attention_2": flash_attention_forward, + "flex_attention": flex_attention_forward, + "sdpa": sdpa_attention_forward, + } +) diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py index c3e3e424a4baa4..6481d6f3c434c7 100644 --- a/src/transformers/models/aria/modeling_aria.py +++ b/src/transformers/models/aria/modeling_aria.py @@ -18,24 +18,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
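The integration files and the `ALL_ATTENTION_FUNCTIONS` registry introduced above all rely on one calling convention, which is what lets the registry stay a plain string-keyed dict. Below is a minimal sketch of that convention, not part of the patch: inputs are assumed to arrive heads-first as `(batch, num_heads, seq_len, head_dim)` and the callable returns `(attn_output, attn_weights)`, mirroring the built-in `"sdpa"`/`"flash_attention_2"`/`"flex_attention"` entries; the `my_attention_forward` name and `"my_attention"` key are hypothetical.

```python
from typing import Optional, Tuple

import torch


def my_attention_forward(
    module: torch.nn.Module,
    query: torch.Tensor,  # (batch, num_heads, seq_len, head_dim)
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    dropout: float = 0.0,
    scaling: Optional[float] = None,
    **kwargs,
) -> Tuple[torch.Tensor, None]:
    # Delegate to PyTorch SDPA purely to illustrate the expected layout: the calling
    # attention module passes heads-first tensors and reshapes the returned
    # (batch, seq_len, num_heads, head_dim) output itself. GQA key/value repetition
    # (repeat_kv) is omitted here for brevity.
    attn_output = torch.nn.functional.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, dropout_p=dropout, scale=scaling
    )
    return attn_output.transpose(1, 2).contiguous(), None


# Hypothetical registration next to the built-in keys added in modeling_utils.py above:
# ALL_ATTENTION_FUNCTIONS["my_attention"] = my_attention_forward
```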
-import math from dataclasses import dataclass -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter -from ...modeling_flash_attention_utils import FlashAttentionKwargs, _flash_attention_forward +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import ( LossKwargs, add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, ) @@ -478,144 +476,73 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + class AriaTextAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__(self, config: AriaTextConfig, layer_idx: Optional[int] = None): + def __init__(self, config: AriaTextConfig, layer_idx: int): super().__init__() self.config = config self.layer_idx = layer_idx - if layer_idx is None: - logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " - "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class." 
- ) - + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 self.attention_dropout = config.attention_dropout - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads) - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.max_position_embeddings = config.max_position_embeddings - self.rope_theta = config.rope_theta self.is_causal = True - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - # use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used - query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = 
attn_output.transpose(1, 2).contiguous() - - attn_output = attn_output.reshape(bsz, q_len, -1) - - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class AriaTextFlashAttention2(AriaTextAttention): - """ - AriaText flash attention module. This module inherits from `AriaTextAttention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if isinstance(past_key_value, StaticCache): - raise ValueError( - "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " - "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" - ) - - output_attentions = False - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) - # Flash attention requires the input to have the shape - # batch_size x seq_length x head_dim x hidden_dim - # therefore we just need to keep the original shape - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = 
self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) @@ -625,159 +552,30 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache - # to be able to avoid many of these transpose/reshape/view. - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - dropout_rate = self.attention_dropout if self.training else 0.0 - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. (AriaTextRMSNorm handles it correctly) - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - attn_output = _flash_attention_forward( + attn_output, attn_weights = attention_interface( + self, query_states, key_states, value_states, attention_mask, - q_len, - position_ids=position_ids, - dropout=dropout_rate, - sliding_window=getattr(self, "sliding_window", None), - use_top_left_mask=self._flash_attn_uses_top_left_mask, - is_causal=self.is_causal, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, **kwargs, ) - attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class AriaTextSdpaAttention(AriaTextAttention): - """ - AriaText attention module using torch.nn.functional.scaled_dot_product_attention. 
This module inherits from - `AriaTextAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to - SDPA API. - """ - - # Adapted from AriaTextAttention.forward - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "AriaTextModel is using AriaTextSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - ) - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - # use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used - query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - causal_mask = attention_mask - if attention_mask is not None: - causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and causal_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. 
- is_causal = True if causal_mask is None and q_len > 1 else False - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=causal_mask, - dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=is_causal, - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(bsz, q_len, -1) - - attn_output = self.o_proj(attn_output) - - return attn_output, None, past_key_value - - -ARIA_TEXT_ATTENTION_CLASSES = { - "eager": AriaTextAttention, - "flash_attention_2": AriaTextFlashAttention2, - "sdpa": AriaTextSdpaAttention, -} + return attn_output, attn_weights class AriaTextDecoderLayer(nn.Module): @@ -797,7 +595,7 @@ def __init__(self, config: AriaTextConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = ARIA_TEXT_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + self.self_attn = AriaTextAttention(config=config, layer_idx=layer_idx) self.mlp = AriaTextMoELayer(config) self.input_layernorm = AriaTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = AriaTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -812,36 +610,14 @@ def forward( use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - **kwargs, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): - attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, - query_sequence_length, key_sequence_length)` if default attention is used. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence - position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): - Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, - with `head_dim` being the embedding dimension of each attention head. 
- kwargs (`dict`, *optional*): - Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code - into the model - """ residual = hidden_states hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, @@ -861,13 +637,9 @@ def forward( hidden_states = residual + hidden_states outputs = (hidden_states,) - if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) - return outputs @@ -953,40 +725,18 @@ def _init_weights(self, module): class AriaTextRotaryEmbedding(nn.Module): def __init__( self, - dim=None, - max_position_embeddings=2048, - base=10000, + config: AriaTextConfig, device=None, - scaling_factor=1.0, - rope_type="default", - config: Optional[AriaTextConfig] = None, ): super().__init__() - # TODO (joao): remove the `if` below, only used for BC self.rope_kwargs = {} - if config is None: - logger.warning_once( - "`AriaTextRotaryEmbedding` can now be fully parameterized by passing the model config through the " - "`config` argument. All other arguments will be removed in v4.46" - ) - self.rope_kwargs = { - "rope_type": rope_type, - "factor": scaling_factor, - "dim": dim, - "base": base, - "max_position_embeddings": max_position_embeddings, - } - self.rope_type = rope_type - self.max_seq_len_cached = max_position_embeddings - self.original_max_seq_len = max_position_embeddings + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: - # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] @@ -1136,8 +886,6 @@ def __init__(self, config: AriaTextConfig): self.norm = AriaTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = AriaTextRotaryEmbedding(config=config) self.gradient_checkpointing = False - if getattr(config, "pretraining_tp", 1) != 1: - logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.") # Initialize weights and apply final processing self.post_init() @@ -1154,7 +902,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, @@ -1182,31 +930,22 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - # kept for BC (non `Cache` `past_key_values` inputs) - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True - if past_key_values is None: - past_key_values = DynamicCache() - else: 
- past_key_values = DynamicCache.from_legacy_cache(past_key_values) - logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " - "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class " - "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" - ) + if use_cache and past_key_values is None: + past_key_values = DynamicCache() if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) + if position_ids is None: position_ids = cache_position.unsqueeze(0) causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions ) + hidden_states = inputs_embeds # create position embeddings to be shared across the decoder layers @@ -1215,7 +954,6 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_decoder_cache = None for decoder_layer in self.layers[: self.config.num_hidden_layers]: if output_hidden_states: @@ -1248,9 +986,6 @@ def forward( hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -1260,18 +995,13 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None - if return_legacy_cache: - next_cache = next_cache.to_legacy_cache() - - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( + output = BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, ) + return output if return_dict else output.to_tuple() def _update_causal_mask( self, diff --git a/src/transformers/models/bark/modeling_bark.py b/src/transformers/models/bark/modeling_bark.py index 9e225ac9ae15c0..36a278263b558a 100644 --- a/src/transformers/models/bark/modeling_bark.py +++ b/src/transformers/models/bark/modeling_bark.py @@ -197,7 +197,6 @@ class BarkSelfFlashAttention2(BarkSelfAttention): flash attention and deal with padding tokens in case the input contains any of them. """ - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index dd1b69c8127fb8..4e1f0b389d42ea 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -294,7 +294,6 @@ class BartFlashAttention2(BartAttention): flash attention and deal with padding tokens in case the input contains any of them. 
""" - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/src/transformers/models/chameleon/modeling_chameleon.py b/src/transformers/models/chameleon/modeling_chameleon.py index f01665201bfa21..11bc411a00c005 100644 --- a/src/transformers/models/chameleon/modeling_chameleon.py +++ b/src/transformers/models/chameleon/modeling_chameleon.py @@ -362,7 +362,7 @@ def forward( return attn_output, attn_weights, past_key_value -# copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->Chameleon +# NO LONGER EXIST copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->Chameleon # TODO(joao): add me back asap :) class ChameleonFlashAttention2(ChameleonAttention): """ diff --git a/src/transformers/models/clip/modeling_clip.py b/src/transformers/models/clip/modeling_clip.py index 4751bb91aace29..0bd9c9c0abce2f 100644 --- a/src/transformers/models/clip/modeling_clip.py +++ b/src/transformers/models/clip/modeling_clip.py @@ -401,7 +401,6 @@ class CLIPFlashAttention2(CLIPAttention): flash attention and deal with padding tokens in case the input contains any of them. """ - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index b9a235ed500c0c..7b8b9547ac1c33 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -351,7 +351,8 @@ def forward( return attn_output, attn_weights, past_key_value -# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->Cohere +# NO LONGER EXIST Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->Cohere +# TODO cyril: modular class CohereFlashAttention2(CohereAttention): """ Cohere flash attention module. This module inherits from `CohereAttention` as the weights of the module stays @@ -760,7 +761,8 @@ def _init_weights(self, module): "The bare Cohere Model outputting raw hidden-states without any specific head on top.", COHERE_START_DOCSTRING, ) -# Copied from transformers.models.llama.modeling_llama.LlamaModel with Llama->Cohere, LLAMA->COHERE +# copied from transformers.models.llama.modeling_llama.LlamaModel with Llama->Cohere, LLAMA->COHERE +# TODO cyril: modular class CohereModel(CoherePreTrainedModel): """ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`CohereDecoderLayer`] @@ -826,31 +828,22 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - # kept for BC (non `Cache` `past_key_values` inputs) - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True - if past_key_values is None: - past_key_values = DynamicCache() - else: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " - "will be removed in v4.47. 
Please convert your cache or use an appropriate `Cache` class " - "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" - ) + if use_cache and past_key_values is None: + past_key_values = DynamicCache() if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) + if position_ids is None: position_ids = cache_position.unsqueeze(0) causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions ) + hidden_states = inputs_embeds # create position embeddings to be shared across the decoder layers @@ -859,7 +852,6 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_decoder_cache = None for decoder_layer in self.layers[: self.config.num_hidden_layers]: if output_hidden_states: @@ -892,9 +884,6 @@ def forward( hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -904,18 +893,13 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None - if return_legacy_cache: - next_cache = next_cache.to_legacy_cache() - - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( + output = BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, ) + return output if return_dict else output.to_tuple() def _update_causal_mask( self, diff --git a/src/transformers/models/cohere2/modeling_cohere2.py b/src/transformers/models/cohere2/modeling_cohere2.py index 6b19d178341fbb..1ffa4bffddc3df 100644 --- a/src/transformers/models/cohere2/modeling_cohere2.py +++ b/src/transformers/models/cohere2/modeling_cohere2.py @@ -659,11 +659,8 @@ def __init__(self, config: Cohere2Config): [Cohere2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) self.norm = Cohere2LayerNorm(hidden_size=(config.hidden_size), eps=config.layer_norm_eps) - - self.gradient_checkpointing = False - if getattr(config, "pretraining_tp", 1) != 1: - logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.") self.rotary_emb = Cohere2RotaryEmbedding(config=config) + self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() diff --git a/src/transformers/models/data2vec/modeling_data2vec_audio.py b/src/transformers/models/data2vec/modeling_data2vec_audio.py index 590509eaf9057c..03102d22ca0d77 100755 --- a/src/transformers/models/data2vec/modeling_data2vec_audio.py +++ b/src/transformers/models/data2vec/modeling_data2vec_audio.py @@ -489,7 +489,6 @@ class Data2VecAudioFlashAttention2(Data2VecAudioAttention): flash attention and deal with padding tokens in case the input contains any of them. 
""" - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py index 7d20b766658f23..0d2c4297e0d473 100644 --- a/src/transformers/models/dbrx/modeling_dbrx.py +++ b/src/transformers/models/dbrx/modeling_dbrx.py @@ -46,7 +46,6 @@ _CONFIG_FOR_DOC = "DbrxConfig" -# Copied from transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding with Gemma->Dbrx class DbrxRotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() @@ -318,7 +317,6 @@ class DbrxFlashAttention2(DbrxAttention): calls the public API of flash attention. """ - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/src/transformers/models/decision_transformer/modeling_decision_transformer.py b/src/transformers/models/decision_transformer/modeling_decision_transformer.py index b8eb9f5a8b4222..60fea55d87be5d 100755 --- a/src/transformers/models/decision_transformer/modeling_decision_transformer.py +++ b/src/transformers/models/decision_transformer/modeling_decision_transformer.py @@ -17,7 +17,7 @@ import math import os from dataclasses import dataclass -from typing import Optional, Tuple, Union +from typing import Callable, Optional, Tuple, Union import torch import torch.utils.checkpoint @@ -25,7 +25,7 @@ from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer from ...utils import ( ModelOutput, @@ -100,6 +100,49 @@ def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path): return model +# Copied from transformers.models.gpt2.modeling_gpt2.eager_attention_forward +def eager_attention_forward(module, query, key, value, attention_mask, head_mask=None, **kwargs): + attn_weights = torch.matmul(query, key.transpose(-1, -2)) + + if module.scale_attn_weights: + attn_weights = attn_weights / torch.full( + [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device + ) + + # Layer-wise attention scaling + if module.scale_attn_by_inverse_layer_idx: + attn_weights = attn_weights / float(module.layer_idx + 1) + + if not module.is_cross_attention: + # if only "normal" attention layer implements causal mask + query_length, key_length = query.size(-2), key.size(-2) + causal_mask = module.bias[:, :, key_length - query_length : key_length, :key_length] + mask_value = torch.finfo(attn_weights.dtype).min + # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. 
+ # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` + mask_value = torch.full([], mask_value, dtype=attn_weights.dtype, device=attn_weights.device) + attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value) + + if attention_mask is not None: + # Apply the attention mask + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise + attn_weights = attn_weights.type(value.dtype) + attn_weights = module.attn_dropout(attn_weights) + + # Mask heads if we want to + if head_mask is not None: + attn_weights = attn_weights * head_mask + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2) + + return attn_output, attn_weights + + # Copied from transformers.models.gpt2.modeling_gpt2.GPT2Attention with GPT2->DecisionTransformerGPT2 class DecisionTransformerGPT2Attention(nn.Module): def __init__(self, config, is_cross_attention=False, layer_idx=None): @@ -161,46 +204,6 @@ def prune_heads(self, heads): self.num_heads = self.num_heads - len(heads) self.pruned_heads = self.pruned_heads.union(heads) - def _attn(self, query, key, value, attention_mask=None, head_mask=None): - attn_weights = torch.matmul(query, key.transpose(-1, -2)) - - if self.scale_attn_weights: - attn_weights = attn_weights / torch.full( - [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device - ) - - # Layer-wise attention scaling - if self.scale_attn_by_inverse_layer_idx: - attn_weights = attn_weights / float(self.layer_idx + 1) - - if not self.is_cross_attention: - # if only "normal" attention layer implements causal mask - query_length, key_length = query.size(-2), key.size(-2) - causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length] - mask_value = torch.finfo(attn_weights.dtype).min - # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. 
- # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` - mask_value = torch.full([], mask_value, dtype=attn_weights.dtype, device=attn_weights.device) - attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value) - - if attention_mask is not None: - # Apply the attention mask - attn_weights = attn_weights + attention_mask - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise - attn_weights = attn_weights.type(value.dtype) - attn_weights = self.attn_dropout(attn_weights) - - # Mask heads if we want to - if head_mask is not None: - attn_weights = attn_weights * head_mask - - attn_output = torch.matmul(attn_weights, value) - - return attn_output, attn_weights - def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None): # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM) bsz, num_heads, q_seq_len, dk = query.size() @@ -250,25 +253,10 @@ def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, hea attn_weights = attn_weights * head_mask attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2) return attn_output, attn_weights - def _split_heads(self, tensor, num_heads, attn_head_size): - """ - Splits hidden_size dim into attn_head_size and num_heads - """ - new_shape = tensor.size()[:-1] + (num_heads, attn_head_size) - tensor = tensor.view(new_shape) - return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) - - def _merge_heads(self, tensor, num_heads, attn_head_size): - """ - Merges attn_head_size dim and num_attn_heads dim into hidden_size - """ - tensor = tensor.permute(0, 2, 1, 3).contiguous() - new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,) - return tensor.view(new_shape) - def forward( self, hidden_states: Optional[Tuple[torch.FloatTensor]], @@ -279,6 +267,7 @@ def forward( encoder_attention_mask: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = False, output_attentions: Optional[bool] = False, + **kwargs, ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: if encoder_hidden_states is not None: if not hasattr(self, "q_attn"): @@ -287,32 +276,65 @@ def forward( "Please make sure to instantiate class with `DecisionTransformerGPT2Attention(..., is_cross_attention=True)`." 
) - query = self.q_attn(hidden_states) - key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) + query_states = self.q_attn(hidden_states) + key_states, value_states = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) attention_mask = encoder_attention_mask else: - query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) + query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2) + + shape_q = (*query_states.shape[:-1], -1, self.head_dim) + shape_kv = (*key_states.shape[:-1], -1, self.head_dim) - query = self._split_heads(query, self.num_heads, self.head_dim) - key = self._split_heads(key, self.num_heads, self.head_dim) - value = self._split_heads(value, self.num_heads, self.head_dim) + query_states = query_states.reshape(shape_q).transpose(1, 2) + key_states = key_states.reshape(shape_kv).transpose(1, 2) + value_states = value_states.reshape(shape_kv).transpose(1, 2) if layer_past is not None: past_key, past_value = layer_past - key = torch.cat((past_key, key), dim=-2) - value = torch.cat((past_value, value), dim=-2) + key_states = torch.cat((past_key, key_states), dim=-2) + value_states = torch.cat((past_value, value_states), dim=-2) if use_cache is True: - present = (key, value) + present = (key_states, value_states) else: present = None - if self.reorder_and_upcast_attn: - attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask) + is_cross_attention = encoder_hidden_states is not None + is_causal = attention_mask is None and query_states.shape[-2] > 1 and not is_cross_attention + + using_eager = self.config._attn_implementation == "eager" + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and (output_attentions or head_mask is not None): + using_eager = True + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + # Attention functions are consistent with previous equivalent attention classes, however they do not support some options + # (e.g. layer scaling, head mask) that eager supports. These implementations are thus equivalent to previous code, but + # not necessarily to eager (if mentionned options are provided). 
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + if using_eager and self.reorder_and_upcast_attn: + attn_output, attn_weights = self._upcast_and_reordered_attn( + query_states, key_states, value_states, attention_mask, head_mask + ) else: - attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + head_mask=head_mask, + dropout=self.attn_dropout.p if self.training else 0.0, + is_causal=is_causal, + **kwargs, + ) - attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim) + attn_output = attn_output.reshape(*attn_output.shape[:-2], -1).contiguous() attn_output = self.c_proj(attn_output) attn_output = self.resid_dropout(attn_output) diff --git a/src/transformers/models/distilbert/modeling_distilbert.py b/src/transformers/models/distilbert/modeling_distilbert.py index 36e35594b3d3c6..a826272956e503 100755 --- a/src/transformers/models/distilbert/modeling_distilbert.py +++ b/src/transformers/models/distilbert/modeling_distilbert.py @@ -245,7 +245,6 @@ class DistilBertFlashAttention2(MultiHeadSelfAttention): API of flash attention and deal with padding tokens in case the input contains any of them. """ - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index 51d9ff39d48f88..8d5a224f4f6654 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -113,40 +113,18 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): class FalconRotaryEmbedding(nn.Module): def __init__( self, - dim=None, - max_position_embeddings=2048, - base=10000, + config: FalconConfig, device=None, - scaling_factor=1.0, - rope_type="default", - config: Optional[FalconConfig] = None, ): super().__init__() - # TODO (joao): remove the `if` below, only used for BC self.rope_kwargs = {} - if config is None: - logger.warning_once( - "`FalconRotaryEmbedding` can now be fully parameterized by passing the model config through the " - "`config` argument. 
All other arguments will be removed in v4.46" - ) - self.rope_kwargs = { - "rope_type": rope_type, - "factor": scaling_factor, - "dim": dim, - "base": base, - "max_position_embeddings": max_position_embeddings, - } - self.rope_type = rope_type - self.max_seq_len_cached = max_position_embeddings - self.original_max_seq_len = max_position_embeddings + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: - # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] @@ -492,7 +470,6 @@ class FalconFlashAttention2(FalconAttention): flash attention and deal with padding tokens in case the input contains any of them. """ - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index b3253fdd5614e1..e2ea12b03fe434 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -19,8 +19,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import math -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch from torch import nn @@ -29,19 +28,21 @@ from ...cache_utils import Cache, DynamicCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter -from ...modeling_flash_attention_utils import _flash_attention_forward +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast, TokenClassifierOutput, ) -from ...modeling_utils import PreTrainedModel +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack from ...utils import ( + LossKwargs, add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, ) @@ -74,24 +75,72 @@ def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.eps}" +class GemmaMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + class GemmaRotaryEmbedding(nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + def __init__( + self, + config: GemmaConfig, + device=None, + ): super().__init__() + self.rope_kwargs = {} + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim)) - self.register_buffer("inv_freq", tensor=inv_freq, persistent=False) + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = 
seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len @torch.no_grad() - def forward(self, x, position_ids, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - self.inv_freq.to(x.device) + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() - # Force float32 since bfloat16 loses precision on long contexts - # See https://github.com/huggingface/transformers/pull/29285 + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): @@ -99,60 +148,12 @@ def forward(self, x, position_ids, seq_len=None): emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) - - -class GemmaLinearScalingRotaryEmbedding(GemmaRotaryEmbedding): - """GemmaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" - - def forward(self, x, position_ids): - # difference to the original RoPE: a scaling factor is aplied to the position ids - position_ids = position_ids.float() / self.scaling_factor - cos, sin = super().forward(x, position_ids) - return cos, sin - - -class GemmaDynamicNTKScalingRotaryEmbedding(GemmaRotaryEmbedding): - """GemmaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" - - def forward(self, x, position_ids): - # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length - seq_len = torch.max(position_ids) + 1 - if seq_len > self.max_position_embeddings: - base = self.base * ( - (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) - ) ** (self.dim / (self.dim - 2)) - inv_freq = 1.0 / ( - base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim) - ) - self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: this may break with compilation - cos, sin = super().forward(x, position_ids) - return cos, sin + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling - -class GemmaMLP(nn.Module): - def __init__(self, config): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) - if config.hidden_activation is None: - logger.warning_once( - "`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.\n" - "Gemma's activation function will be set to `gelu_pytorch_tanh`. 
Please, use\n" - "`config.hidden_activation` if you want to override this behaviour.\n" - "See https://github.com/huggingface/transformers/pull/29402 for more details." - ) - config.hidden_activation = "gelu_pytorch_tanh" - hidden_activation = config.hidden_activation - self.act_fn = ACT2FN[hidden_activation] - - def forward(self, x): - return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) def rotate_half(x): @@ -201,241 +202,75 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + class GemmaAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__(self, config: GemmaConfig, layer_idx: Optional[int] = None): + def __init__(self, config: GemmaConfig, layer_idx: int): super().__init__() self.config = config self.layer_idx = layer_idx - if layer_idx is None: - logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " - "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class." - ) - + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 self.attention_dropout = config.attention_dropout - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = config.head_dim - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.max_position_embeddings = config.max_position_embeddings - self.rope_theta = config.rope_theta self.is_causal = True - self.scaling = 1 / math.sqrt(config.head_dim) - if self.hidden_size % self.num_heads != 0: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." 
- ) - - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) - self.rotary_emb = GemmaRotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - base=self.rope_theta, + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias ) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - cos, sin = self.rotary_emb(value_states, position_ids) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling - - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - - attn_output = attn_output.view(bsz, q_len, -1) - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class GemmaSdpaAttention(GemmaAttention): - """ - Gemma attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from - `GemmaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to - SDPA API. 
- """ - - # Adapted from GemmaAttention.forward - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "GemmaModel is using GemmaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - ) - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - cos, sin = self.rotary_emb(value_states, position_ids) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - causal_mask = attention_mask - if attention_mask is not None: - causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and causal_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. 
- is_causal = True if causal_mask is None and q_len > 1 else False - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=causal_mask, - dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=is_causal, + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(bsz, q_len, -1) - - attn_output = self.o_proj(attn_output) - - return attn_output, None, past_key_value - - -class GemmaFlashAttention2(GemmaAttention): - """ - Gemma flash attention module. This module inherits from `GemmaAttention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
- self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if isinstance(past_key_value, StaticCache): - raise ValueError( - "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " - "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" - ) - - output_attentions = False - - bsz, q_len, _ = hidden_states.size() + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - # Flash attention requires the input to have the shape - # batch_size x seq_length x head_dim x hidden_dim - # therefore we just need to keep the original shape - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - cos, sin = self.rotary_emb(value_states, position_ids) + cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: @@ -443,73 +278,39 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache - # to be able to avoid many of these transpose/reshape/view. - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - dropout_rate = self.attention_dropout if self.training else 0.0 - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. 
(GemmaRMSNorm handles it correctly) - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - attn_output = _flash_attention_forward( + attn_output, attn_weights = attention_interface( + self, query_states, key_states, value_states, attention_mask, - q_len, - position_ids=position_ids, - dropout=dropout_rate, - sliding_window=getattr(self, "sliding_window", None), - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, ) - attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - return attn_output, attn_weights, past_key_value - - -GEMMA_ATTENTION_CLASSES = { - "eager": GemmaAttention, - "flash_attention_2": GemmaFlashAttention2, - "sdpa": GemmaSdpaAttention, -} + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights class GemmaDecoderLayer(nn.Module): def __init__(self, config: GemmaConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = GEMMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + + self.self_attn = GemmaAttention(config=config, layer_idx=layer_idx) + self.mlp = GemmaMLP(config) self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -523,33 +324,15 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - **kwargs, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): - attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, - query_sequence_length, key_sequence_length)` if default 
attention is used. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence - kwargs (`dict`, *optional*): - Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code - into the model - """ residual = hidden_states hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, @@ -557,6 +340,7 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + position_embeddings=position_embeddings, **kwargs, ) hidden_states = residual + hidden_states @@ -568,13 +352,9 @@ def forward( hidden_states = residual + hidden_states outputs = (hidden_states,) - if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) - return outputs @@ -720,10 +500,8 @@ def __init__(self, config: GemmaConfig): [GemmaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - + self.rotary_emb = GemmaRotaryEmbedding(config=config) self.gradient_checkpointing = False - if getattr(config, "pretraining_tp", 1) != 1: - logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.") # Initialize weights and apply final processing self.post_init() @@ -767,19 +545,8 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - # kept for BC (non `Cache` `past_key_values` inputs) - return_legacy_cache = False # noqa: F841 - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True # noqa: F841 - if past_key_values is None: - past_key_values = DynamicCache() - else: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " - "will be removed in v4.47. 
Please convert your cache or use an appropriate `Cache` class " - "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" - ) + if use_cache and past_key_values is None: + past_key_values = DynamicCache() if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 @@ -797,6 +564,9 @@ def forward( # embed positions hidden_states = inputs_embeds + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + # normalized # Gemma downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5 # See https://github.com/huggingface/transformers/pull/29402 @@ -806,7 +576,6 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_decoder_cache = None for decoder_layer in self.layers[: self.config.num_hidden_layers]: if output_hidden_states: @@ -822,6 +591,7 @@ def forward( output_attentions, use_cache, cache_position, + position_embeddings, ) else: layer_outputs = decoder_layer( @@ -832,13 +602,11 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + position_embeddings=position_embeddings, ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -848,18 +616,13 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None - if return_legacy_cache: - next_cache = next_cache.to_legacy_cache() - - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( + output = BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, ) + return output if return_dict else output.to_tuple() def _update_causal_mask( self, @@ -983,6 +746,9 @@ def _prepare_4d_causal_attention_mask_with_cache_position( return causal_mask +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... 
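[Editor's note] The hunks above (DecisionTransformerGPT2Attention and GemmaAttention) replace per-backend attention classes with a single module that dispatches through interchangeable attention callables sharing one signature, with eager as the fallback. The sketch below illustrates that dispatch pattern only; the names `eager_attention`, `sdpa_attention`, `ATTENTION_FUNCTIONS`, and `dispatch_attention` are illustrative stand-ins, not the library's API, and grouped-query key/value repetition is omitted.

```python
from typing import Callable, Dict, Optional

import torch
from torch import nn


def eager_attention(module: nn.Module, query, key, value, attention_mask: Optional[torch.Tensor],
                    scaling: float, dropout: float = 0.0, **kwargs):
    # Plain matmul + softmax attention, shaped like `eager_attention_forward` in the diff
    # (key/value head repetition for GQA is left out to keep the sketch short).
    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask[:, :, :, : key.shape[-2]]
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value).transpose(1, 2).contiguous()
    return attn_output, attn_weights


def sdpa_attention(module: nn.Module, query, key, value, attention_mask, scaling: float,
                   dropout: float = 0.0, **kwargs):
    # Same signature, different backend; returns None for the weights, like the SDPA path.
    # The `scale=` argument needs a recent PyTorch; drop it to use the default 1/sqrt(head_dim).
    attn_output = nn.functional.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, dropout_p=dropout, scale=scaling
    )
    return attn_output.transpose(1, 2).contiguous(), None


# Hypothetical registry standing in for ALL_ATTENTION_FUNCTIONS.
ATTENTION_FUNCTIONS: Dict[str, Callable] = {"eager": eager_attention, "sdpa": sdpa_attention}


def dispatch_attention(module: nn.Module, impl: str, *args, **kwargs):
    # Pick the backend by name and fall back to eager, mirroring
    # `attention_interface = ALL_ATTENTION_FUNCTIONS[config._attn_implementation]` above.
    return ATTENTION_FUNCTIONS.get(impl, eager_attention)(module, *args, **kwargs)
```

For example, `dispatch_attention(some_module, "sdpa", q, k, v, mask, scaling=0.125)` with `[batch, heads, seq, head_dim]` tensors would route to the SDPA backend, while an unknown implementation name falls back to the eager path.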
+ + class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} @@ -1030,7 +796,7 @@ def forward( return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, num_logits_to_keep: int = 0, - **loss_kwargs, + **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -1080,6 +846,7 @@ def forward( output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, + **kwargs, ) hidden_states = outputs[0] @@ -1088,7 +855,7 @@ def forward( loss = None if labels is not None: - loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs) + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) if not return_dict: output = (logits,) + outputs[1:] diff --git a/src/transformers/models/gemma/modular_gemma.py b/src/transformers/models/gemma/modular_gemma.py index 778ef7e19b65b6..29b6f8a1946173 100644 --- a/src/transformers/models/gemma/modular_gemma.py +++ b/src/transformers/models/gemma/modular_gemma.py @@ -13,7 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import math from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union import sentencepiece as spm @@ -21,24 +20,17 @@ import torch.utils.checkpoint from torch import nn -from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, StaticCache +from ...cache_utils import Cache, DynamicCache from ...configuration_utils import PretrainedConfig -from ...modeling_flash_attention_utils import _flash_attention_forward -from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast -from ...pytorch_utils import ALL_LAYERNORM_LAYERS +from ...modeling_outputs import BaseModelOutputWithPast from ...tokenization_utils import AddedToken, PreTrainedTokenizer from ...utils import logging from ..llama.modeling_llama import ( - LlamaDecoderLayer, - LlamaFlashAttention2, LlamaForCausalLM, LlamaForSequenceClassification, LlamaForTokenClassification, + LlamaMLP, LlamaModel, - LlamaPreTrainedModel, - apply_rotary_pos_emb, - repeat_kv, ) from ..llama.tokenization_llama import LlamaTokenizer @@ -352,472 +344,15 @@ def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.eps}" -ALL_LAYERNORM_LAYERS.append(GemmaRMSNorm) - - -class GemmaRotaryEmbedding(nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): - super().__init__() - - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim)) - self.register_buffer("inv_freq", tensor=inv_freq, persistent=False) - - @torch.no_grad() - def forward(self, x, position_ids, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - self.inv_freq.to(x.device) - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) - position_ids_expanded = position_ids[:, None, :].float() - # Force float32 since bfloat16 loses precision on long contexts - # See https://github.com/huggingface/transformers/pull/29285 - device_type = x.device.type - device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" - with torch.autocast(device_type=device_type, 
enabled=False): - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() - sin = emb.sin() - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) - - -class GemmaLinearScalingRotaryEmbedding(GemmaRotaryEmbedding): - """GemmaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" - - def forward(self, x, position_ids): - # difference to the original RoPE: a scaling factor is aplied to the position ids - position_ids = position_ids.float() / self.scaling_factor - cos, sin = super().forward(x, position_ids) - return cos, sin - - -class GemmaDynamicNTKScalingRotaryEmbedding(GemmaRotaryEmbedding): - """GemmaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" - - def forward(self, x, position_ids): - # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length - seq_len = torch.max(position_ids) + 1 - if seq_len > self.max_position_embeddings: - base = self.base * ( - (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) - ) ** (self.dim / (self.dim - 2)) - inv_freq = 1.0 / ( - base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim) - ) - self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: this may break with compilation - - cos, sin = super().forward(x, position_ids) - return cos, sin - - -class GemmaMLP(nn.Module): +class GemmaMLP(LlamaMLP): def __init__(self, config): super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) - if config.hidden_activation is None: - logger.warning_once( - "`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.\n" - "Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use\n" - "`config.hidden_activation` if you want to override this behaviour.\n" - "See https://github.com/huggingface/transformers/pull/29402 for more details." - ) - config.hidden_activation = "gelu_pytorch_tanh" - hidden_activation = config.hidden_activation - self.act_fn = ACT2FN[hidden_activation] - - def forward(self, x): - return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - - -class GemmaAttention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__(self, config: GemmaConfig, layer_idx: Optional[int] = None): - super().__init__() - self.config = config - self.layer_idx = layer_idx - if layer_idx is None: - logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " - "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class." 
- ) - - self.attention_dropout = config.attention_dropout - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = config.head_dim - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.max_position_embeddings = config.max_position_embeddings - self.rope_theta = config.rope_theta - self.is_causal = True - self.scaling = 1 / math.sqrt(config.head_dim) - - if self.hidden_size % self.num_heads != 0: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." - ) - - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) - self.rotary_emb = GemmaRotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - base=self.rope_theta, - ) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - cos, sin = self.rotary_emb(value_states, position_ids) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling - - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 
2).contiguous() - - attn_output = attn_output.view(bsz, q_len, -1) - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class GemmaSdpaAttention(GemmaAttention): - """ - Gemma attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from - `GemmaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to - SDPA API. - """ - - # Adapted from GemmaAttention.forward - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "GemmaModel is using GemmaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - ) - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - cos, sin = self.rotary_emb(value_states, position_ids) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - causal_mask = attention_mask - if attention_mask is not None: - causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. 
- if query_states.device.type == "cuda" and causal_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - is_causal = True if causal_mask is None and q_len > 1 else False - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=causal_mask, - dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=is_causal, - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(bsz, q_len, -1) - - attn_output = self.o_proj(attn_output) - - return attn_output, None, past_key_value - - -class GemmaFlashAttention2(LlamaFlashAttention2, GemmaAttention): - """ - Gemma flash attention module. This module inherits from `GemmaAttention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if isinstance(past_key_value, StaticCache): - raise ValueError( - "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " - "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" - ) - - output_attentions = False - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - # Flash attention requires the input to have the shape - # batch_size x seq_length x head_dim x hidden_dim - # therefore we just need to keep the original shape - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - cos, sin = self.rotary_emb(value_states, position_ids) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache - # to be able to avoid many of these transpose/reshape/view. 
- query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - dropout_rate = self.attention_dropout if self.training else 0.0 - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. (GemmaRMSNorm handles it correctly) - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - q_len, - position_ids=position_ids, - dropout=dropout_rate, - sliding_window=getattr(self, "sliding_window", None), - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - ) - attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -GEMMA_ATTENTION_CLASSES = { - "eager": GemmaAttention, - "flash_attention_2": GemmaFlashAttention2, - "sdpa": GemmaSdpaAttention, -} - - -class GemmaDecoderLayer(LlamaDecoderLayer): - def __init__(self, config: GemmaConfig, layer_idx: int): - super().__init__(config) - self.self_attn = GEMMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) - self.mlp = GemmaMLP(config) - self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - cache_position: Optional[torch.LongTensor] = None, - **kwargs, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): - attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, - query_sequence_length, key_sequence_length)` if default attention is used. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. 
- use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence - kwargs (`dict`, *optional*): - Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code - into the model - """ - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - **kwargs, - ) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (present_key_value,) - - return outputs - - -class GemmaPreTrainedModel(LlamaPreTrainedModel): - pass class GemmaModel(LlamaModel): - def __init__(self, config: GemmaConfig): - super().__init__(config) - self.layers = nn.ModuleList( - [GemmaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] - ) - self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - del self.rotary_emb # Gemma does not implement rotary emb at the modeling level yet! - self.post_init() - def forward( self, input_ids: torch.LongTensor = None, @@ -850,19 +385,8 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - # kept for BC (non `Cache` `past_key_values` inputs) - return_legacy_cache = False # noqa: F841 - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True # noqa: F841 - if past_key_values is None: - past_key_values = DynamicCache() - else: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " - "will be removed in v4.47. 
Please convert your cache or use an appropriate `Cache` class " - "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" - ) + if use_cache and past_key_values is None: + past_key_values = DynamicCache() if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 @@ -880,6 +404,9 @@ def forward( # embed positions hidden_states = inputs_embeds + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + # normalized # Gemma downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5 # See https://github.com/huggingface/transformers/pull/29402 @@ -889,7 +416,6 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_decoder_cache = None for decoder_layer in self.layers[: self.config.num_hidden_layers]: if output_hidden_states: @@ -905,6 +431,7 @@ def forward( output_attentions, use_cache, cache_position, + position_embeddings, ) else: layer_outputs = decoder_layer( @@ -915,13 +442,11 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + position_embeddings=position_embeddings, ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -931,44 +456,33 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None - if return_legacy_cache: - next_cache = next_cache.to_legacy_cache() - - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( + output = BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, ) + return output if return_dict else output.to_tuple() -# Example where we ony modify the docstring and call super class GemmaForCausalLM(LlamaForCausalLM): - def __init__(self, config): - super().__init__(config) - self.model = GemmaModel(config) - self.post_init() - - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, - **loss_kwargs, - ) -> Union[Tuple, CausalLMOutputWithPast]: + def forward(**super_kwargs): r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. 
If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + + Returns: + + Example: + ```python >>> from transformers import AutoTokenizer, GemmaForCausalLM @@ -983,59 +497,15 @@ def forward( >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] "What is your favorite condiment?" ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - cache_position=cache_position, - ) - - hidden_states = outputs[0] - # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) - - loss = None - if labels is not None: - loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) + return super().forward(**super_kwargs) class GemmaForSequenceClassification(LlamaForSequenceClassification): - def __init__(self, config): - super().__init__(config) - self.model = GemmaModel(config) - self.post_init() + pass class GemmaForTokenClassification(LlamaForTokenClassification): - def __init__(self, config): - super().__init__(config) - self.model = GemmaModel(config) - self.post_init() + pass __all__ = [ @@ -1045,5 +515,5 @@ def __init__(self, config): "GemmaForCausalLM", "GemmaForSequenceClassification", "GemmaForTokenClassification", - "GemmaPreTrainedModel", + "GemmaPreTrainedModel", # noqa: F822 ] diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 288913697f2641..67fc6c86a3bac6 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -19,7 +19,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
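The Gemma modular changes above drop the legacy tuple-of-tuples cache conversion in favour of creating a `DynamicCache` on demand, and collapse the return path into a single `BaseModelOutputWithPast`. A minimal, self-contained sketch of that flow, assuming only public `transformers` classes; `forward_core` is an illustrative stand-in for the body of `GemmaModel.forward`, which runs the decoder layers in between:

```python
# Condensed sketch of the new cache/return flow; `forward_core` is a
# hypothetical stand-in for the trimmed GemmaModel.forward body.
from typing import Optional

import torch
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_outputs import BaseModelOutputWithPast


def forward_core(
    inputs_embeds: torch.Tensor,
    past_key_values: Optional[Cache] = None,
    use_cache: bool = True,
    return_dict: bool = True,
):
    # No legacy tuple-of-tuples handling anymore: a DynamicCache is created on demand.
    if use_cache and past_key_values is None:
        past_key_values = DynamicCache()

    # cache_position picks up right after whatever the cache already holds.
    past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
    cache_position = torch.arange(
        past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
    )

    hidden_states = inputs_embeds  # ... decoder layers would consume cache_position here ...

    # The cache object itself is returned; no per-layer next_decoder_cache bookkeeping.
    output = BaseModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=past_key_values if use_cache else None,
    )
    return output if return_dict else output.to_tuple()


out = forward_core(torch.randn(1, 3, 8))
assert isinstance(out.past_key_values, DynamicCache)
```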
-from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -27,32 +27,26 @@ from ...activations import ACT2FN from ...cache_utils import Cache, HybridCache from ...generation import GenerationMixin +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast, TokenClassifierOutput, ) -from ...modeling_utils import PreTrainedModel +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack from ...utils import ( add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_2_available, - is_flash_attn_greater_or_equal, - is_torch_greater_or_equal, logging, replace_return_docstrings, ) from .configuration_gemma2 import Gemma2Config -if is_flash_attn_2_available(): - from ...modeling_flash_attention_utils import _flash_attention_forward - -if is_torch_greater_or_equal("2.5"): - from torch.nn.attention.flex_attention import flex_attention - logger = logging.get_logger(__name__) @@ -92,35 +86,8 @@ def __init__(self, config): self.act_fn = ACT2FN[config.hidden_activation] def forward(self, x): - return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - - -class Gemma2RotaryEmbedding(nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): - super().__init__() - - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim)) - self.register_buffer("inv_freq", tensor=inv_freq, persistent=False) - - @torch.no_grad() - def forward(self, x, position_ids, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - self.inv_freq.to(x.device) - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) - position_ids_expanded = position_ids[:, None, :].float() - # Force float32 since bfloat16 loses precision on long contexts - # See https://github.com/huggingface/transformers/pull/29285 - device_type = x.device.type - device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() - sin = emb.sin() - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj def rotate_half(x): @@ -170,266 +137,118 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: def eager_attention_forward( - config: Gemma2Config, + module: nn.Module, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - mask: Optional[torch.Tensor], - **_kwargs, + attention_mask: Optional[torch.Tensor], + dropout: float = 0.0, + scaling: Optional[float] = None, + softcap: Optional[float] = None, + **kwargs, ) -> Tuple[torch.Tensor, torch.Tensor]: - key_states = repeat_kv(key, config.num_key_value_groups) - value_states = repeat_kv(value, config.num_key_value_groups) + if scaling is None: + scaling = module.head_dim**-0.5 - attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * config.scaling + key_states 
= repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) - if config.attn_logit_softcapping is not None: - attn_weights = attn_weights / config.attn_logit_softcapping + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + + if softcap is not None: + attn_weights = attn_weights / softcap attn_weights = torch.tanh(attn_weights) - attn_weights = attn_weights * config.attn_logit_softcapping - if mask is not None: # no matter the length, we just slice it - causal_mask = mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights * softcap + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask # upcast attention to fp32 attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=config.attention_dropout, training=config.training) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, attn_weights -def flash_attention_forward( - config: Gemma2Config, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - mask: Optional[torch.Tensor], - target_dtype: torch.dtype = torch.float16, - **_kwargs, -) -> Tuple[torch.Tensor, None]: - if mask is not None: - seq_len = mask.shape[1] - query = query[:, :, :seq_len] - value = value[:, :, :seq_len] - - # TODO: These transpose are quite inefficient but Flash Attention requires the layout - # [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor rotary embedding - query_states = query.transpose(1, 2) - key_states = key.transpose(1, 2) - value_states = value.transpose(1, 2) - - dropout_rate = config.attention_dropout if config.training else 0.0 - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - mask, - seq_len, - dropout=dropout_rate, - softmax_scale=config.scaling, - is_causal=config.is_causal, - sliding_window=config.sliding_window, - use_top_left_mask=config._flash_attn_uses_top_left_mask, - softcap=config.attn_logit_softcapping if is_flash_attn_greater_or_equal("2.6.0") else None, - ) - - return attn_output, None - - -def flex_attention_forward( - config: Gemma2Config, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - mask: Optional[torch.Tensor], - output_attentions: bool = False, - **_kwargs, -) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - def tanh_softcap(score, b, h, q_idx, kv_idx): - soft_cap = config.attn_logit_softcapping - score = soft_cap * torch.tanh(score / soft_cap) - if mask is not None: - return score + mask[b][0][q_idx][kv_idx] - return score - - attn_output = flex_attention( - query, - key, - value, - score_mod=tanh_softcap, - enable_gqa=True, - scale=config.scaling, - return_lse=output_attentions, - ) - if not output_attentions: - attn_weights = None - else: - attn_output, attn_weights = attn_output - - attn_output = attn_output.transpose(1, 2).contiguous() - return attn_output, attn_weights - - -def sdpa_attention_forward( - config: Gemma2Config, - query: torch.Tensor, - key: 
torch.Tensor, - value: torch.Tensor, - mask: Optional[torch.Tensor], - **_kwargs, -) -> Tuple[torch.Tensor, None]: - key = repeat_kv(key, config.num_key_value_groups) - value = repeat_kv(value, config.num_key_value_groups) - - causal_mask = mask - if mask is not None: - causal_mask = causal_mask[:, :, :, : key.shape[-2]] - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query.device.type == "cuda" and causal_mask is not None: - query = query.contiguous() - key = key.contiguous() - value = value.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - is_causal = True if causal_mask is None and query.shape[1] > 1 else False - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query, - key, - value, - attn_mask=causal_mask, - dropout_p=config.attention_dropout if config.training else 0.0, - is_causal=is_causal, - scale=config.scaling, - ) - attn_output = attn_output.transpose(1, 2).contiguous() - return attn_output, None - - -GEMMA2_ATTENTION_FUNCTION = { - "flash_attention_2": flash_attention_forward, - "flex_attention": flex_attention_forward, - "eager": eager_attention_forward, - "sdpa": sdpa_attention_forward, -} - - class Gemma2Attention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None): + def __init__(self, config: Gemma2Config, layer_idx: int): super().__init__() self.config = config self.layer_idx = layer_idx - - self.attention_dropout = config.attention_dropout - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = config.head_dim - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.max_position_embeddings = config.max_position_embeddings - self.rope_theta = config.rope_theta - self.is_causal = True + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.scaling = config.query_pre_attn_scalar**-0.5 - self.sliding_window = config.sliding_window if not bool(layer_idx % 2) else None - self.attn_logit_softcapping = config.attn_logit_softcapping - if self.hidden_size % self.num_heads != 0: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." 
- ) + self.attention_dropout = self.config.attention_dropout + self.is_causal = True - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) - self.rotary_emb = Gemma2RotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - base=self.rope_theta, + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + self.attn_logit_softcapping = self.config.attn_logit_softcapping + self.sliding_window = config.sliding_window if not bool(layer_idx % 2) else None def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - cos, sin = self.rotary_emb(value_states, position_ids) + cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = { - "sin": sin, - "cos": cos, - "sliding_window": self.sliding_window, - "cache_position": cache_position, - } + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - if output_attentions and self.config._attn_implementation in ["sdpa", "flash_attention_2"]: - logger.warning_once("Setting `attention_type` to `flex_attention` because `output_attentions=True`") - attention_type = "flex_attention" - else: - attention_type = self.config._attn_implementation + attention_interface: Callable = 
eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - attn_output, attn_weights = GEMMA2_ATTENTION_FUNCTION[attention_type]( - self, query_states, key_states, value_states, attention_mask, output_attentions=output_attentions + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=self.attention_dropout if self.training else 0.0, + scaling=self.scaling, + sliding_window=self.sliding_window, + softcap=self.attn_logit_softcapping, + **kwargs, ) - attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class Gemma2FlashAttention2(Gemma2Attention): - def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None): - super().__init__(config, layer_idx) - self.config._attn_implementation = "flash_attention_2" - logger.warning_once( - "The `Gemma2FlashAttention2` class is deprecated in favor of simply modifying the `config._attn_implementation`" - "attribute of the `GemmaAttention` class! It will be removed in v4.48" - ) - - -class Gemma2SdpaAttention(Gemma2Attention): - def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None): - super().__init__(config, layer_idx) - self.config._attn_implementation = "sdpa" - logger.warning_once( - "The `Gemma2FlashAttention2` class is deprecated in favor of simply modifying the `config._attn_implementation`" - "attribute of the `GemmaAttention` class! 
It will be removed in v4.48" - ) + return attn_output, attn_weights class Gemma2DecoderLayer(nn.Module): @@ -450,6 +269,7 @@ def __init__(self, config: Gemma2Config, layer_idx: int): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, @@ -476,8 +296,9 @@ def forward( hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, + position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, @@ -499,12 +320,74 @@ def forward( if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) - return outputs +class Gemma2RotaryEmbedding(nn.Module): + def __init__( + self, + config: Gemma2Config, + device=None, + ): + super().__init__() + self.rope_kwargs = {} + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. 
yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + GEMMA2_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads @@ -535,7 +418,7 @@ class Gemma2PreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True _supports_cache_class = True - _supports_quantized_cache = False + _supports_quantized_cache = True _supports_static_cache = True def _init_weights(self, module): @@ -549,20 +432,6 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - @classmethod - def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False): - """ - Overloads `PreTrainedModel._check_and_enable_sdpa` so as to DISABLE torch SDPA by default on Gemma2 models. - SDPA reduces the model performance on Gemma2 because of the logits softcapping. - """ - config = super()._check_and_enable_sdpa(config, hard_check_only=hard_check_only) - - # if using the default path -> swap sdpa by eager - if not hard_check_only and config._attn_implementation == "sdpa": - config._attn_implementation = "eager" - - return config - GEMMA2_INPUTS_DOCSTRING = r""" Args: @@ -661,10 +530,8 @@ def __init__(self, config: Gemma2Config): [Gemma2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) self.norm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - + self.rotary_emb = Gemma2RotaryEmbedding(config=config) self.gradient_checkpointing = False - if getattr(config, "pretraining_tp", 1) != 1: - logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.") # Initialize weights and apply final processing self.post_init() @@ -734,6 +601,9 @@ def forward( # embed positions hidden_states = inputs_embeds + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + # normalized # Gemma2 downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5 # See https://github.com/huggingface/transformers/pull/29402 @@ -752,6 +622,7 @@ def forward( layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, + position_embeddings, causal_mask, position_ids, past_key_values, @@ -762,6 +633,7 @@ def forward( else: layer_outputs = decoder_layer( hidden_states, + position_embeddings=position_embeddings, attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, @@ -780,16 +652,13 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = past_key_values if use_cache else None - - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( + output = BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, ) + return output if return_dict else output.to_tuple() @torch.no_grad() def _update_causal_mask( diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py index 5e04fe1b63a362..48b12411361aff 100644 
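For reference, the softcapped eager path that `modeling_gemma2.py` now routes everything through can be exercised on its own. Below is a self-contained rendition under the assumption that the module attributes are passed explicitly (the in-tree `eager_attention_forward` instead reads `num_key_value_groups` and `training` from the attention module it receives):

```python
# Self-contained rendition of the softcapped eager attention path shown above.
from typing import Optional

import torch


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # (batch, num_kv_heads, seq, head_dim) -> (batch, num_kv_heads * n_rep, seq, head_dim)
    batch, num_kv_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_kv_heads * n_rep, slen, head_dim)


def softcapped_eager_attention(
    query: torch.Tensor,                     # (batch, num_heads, q_len, head_dim)
    key: torch.Tensor,                       # (batch, num_kv_heads, kv_len, head_dim)
    value: torch.Tensor,                     # (batch, num_kv_heads, kv_len, head_dim)
    attention_mask: Optional[torch.Tensor],  # additive mask or None
    num_key_value_groups: int,
    scaling: float,
    softcap: Optional[float] = None,
):
    key_states = repeat_kv(key, num_key_value_groups)
    value_states = repeat_kv(value, num_key_value_groups)

    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if softcap is not None:
        # Gemma2-style logit softcapping: softcap * tanh(logits / softcap)
        attn_weights = torch.tanh(attn_weights / softcap) * softcap
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask[:, :, :, : key_states.shape[-2]]

    attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_output = torch.matmul(attn_weights, value_states).transpose(1, 2).contiguous()
    return attn_output, attn_weights


# Toy GQA shapes: 8 query heads sharing 2 key/value heads.
q = torch.randn(1, 8, 4, 16)
k = torch.randn(1, 2, 4, 16)
v = torch.randn(1, 2, 4, 16)
out, weights = softcapped_eager_attention(q, k, v, None, num_key_value_groups=4, scaling=16**-0.5, softcap=50.0)
assert out.shape == (1, 4, 8, 16)
```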
--- a/src/transformers/models/gemma2/modular_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -13,7 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple, Union +from typing import Callable, Optional, Tuple, Union import torch import torch.nn as nn @@ -22,36 +22,27 @@ from ...activations import ACT2FN from ...cache_utils import Cache, HybridCache from ...configuration_utils import PretrainedConfig +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, ) -from ...utils import ( - is_flash_attn_2_available, - is_flash_attn_greater_or_equal, - is_torch_greater_or_equal, - logging, -) +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS +from ...processing_utils import Unpack +from ...utils import logging from ..gemma.modeling_gemma import ( + GemmaAttention, GemmaForCausalLM, GemmaForSequenceClassification, GemmaForTokenClassification, + GemmaMLP, GemmaModel, - GemmaPreTrainedModel, GemmaRMSNorm, - GemmaRotaryEmbedding, apply_rotary_pos_emb, repeat_kv, ) -if is_flash_attn_2_available(): - from ...modeling_flash_attention_utils import _flash_attention_forward - -if is_torch_greater_or_equal("2.5"): - from torch.nn.attention.flex_attention import flex_attention - - _CHECKPOINT_FOR_DOC = "google/gemma2-7b" logger = logging.get_logger(__name__) @@ -194,286 +185,106 @@ class Gemma2RMSNorm(GemmaRMSNorm): pass -class Gemma2MLP(nn.Module): +class Gemma2MLP(GemmaMLP): def __init__(self, config): super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) self.act_fn = ACT2FN[config.hidden_activation] - def forward(self, x): - return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - - -class Gemma2RotaryEmbedding(GemmaRotaryEmbedding): - pass - def eager_attention_forward( - config: Gemma2Config, + module: nn.Module, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - mask: Optional[torch.Tensor], - **_kwargs, + attention_mask: Optional[torch.Tensor], + dropout: float = 0.0, + scaling: Optional[float] = None, + softcap: Optional[float] = None, + **kwargs, ) -> Tuple[torch.Tensor, torch.Tensor]: - key_states = repeat_kv(key, config.num_key_value_groups) - value_states = repeat_kv(value, config.num_key_value_groups) + if scaling is None: + scaling = module.head_dim**-0.5 + + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) - attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * config.scaling + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling - if config.attn_logit_softcapping is not None: - attn_weights = attn_weights / config.attn_logit_softcapping + if softcap is not None: + attn_weights = attn_weights / softcap attn_weights = torch.tanh(attn_weights) - attn_weights = attn_weights * config.attn_logit_softcapping - if mask is not None: # no matter the length, we just slice it - causal_mask = mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights * softcap + if 
attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask # upcast attention to fp32 attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=config.attention_dropout, training=config.training) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, attn_weights -def flash_attention_forward( - config: Gemma2Config, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - mask: Optional[torch.Tensor], - target_dtype: torch.dtype = torch.float16, - **_kwargs, -) -> Tuple[torch.Tensor, None]: - if mask is not None: - seq_len = mask.shape[1] - query = query[:, :, :seq_len] - value = value[:, :, :seq_len] - - # TODO: These transpose are quite inefficient but Flash Attention requires the layout - # [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor rotary embedding - query_states = query.transpose(1, 2) - key_states = key.transpose(1, 2) - value_states = value.transpose(1, 2) - - dropout_rate = config.attention_dropout if config.training else 0.0 - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - mask, - seq_len, - dropout=dropout_rate, - softmax_scale=config.scaling, - is_causal=config.is_causal, - sliding_window=config.sliding_window, - use_top_left_mask=config._flash_attn_uses_top_left_mask, - softcap=config.attn_logit_softcapping if is_flash_attn_greater_or_equal("2.6.0") else None, - ) - - return attn_output, None - - -def flex_attention_forward( - config: Gemma2Config, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - mask: Optional[torch.Tensor], - output_attentions: bool = False, - **_kwargs, -) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - def tanh_softcap(score, b, h, q_idx, kv_idx): - soft_cap = config.attn_logit_softcapping - score = soft_cap * torch.tanh(score / soft_cap) - if mask is not None: - return score + mask[b][0][q_idx][kv_idx] - return score - - attn_output = flex_attention( - query, - key, - value, - score_mod=tanh_softcap, - enable_gqa=True, - scale=config.scaling, - return_lse=output_attentions, - ) - if not output_attentions: - attn_weights = None - else: - attn_output, attn_weights = attn_output - - attn_output = attn_output.transpose(1, 2).contiguous() - return attn_output, attn_weights - - -def sdpa_attention_forward( - config: Gemma2Config, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - mask: Optional[torch.Tensor], - **_kwargs, -) -> Tuple[torch.Tensor, None]: - key = repeat_kv(key, config.num_key_value_groups) - value = repeat_kv(value, config.num_key_value_groups) - - causal_mask = mask - if mask is not None: - causal_mask = causal_mask[:, :, :, : key.shape[-2]] - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. 
- if query.device.type == "cuda" and causal_mask is not None: - query = query.contiguous() - key = key.contiguous() - value = value.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - is_causal = True if causal_mask is None and query.shape[1] > 1 else False - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query, - key, - value, - attn_mask=causal_mask, - dropout_p=config.attention_dropout if config.training else 0.0, - is_causal=is_causal, - scale=config.scaling, - ) - attn_output = attn_output.transpose(1, 2).contiguous() - return attn_output, None - - -GEMMA2_ATTENTION_FUNCTION = { - "flash_attention_2": flash_attention_forward, - "flex_attention": flex_attention_forward, - "eager": eager_attention_forward, - "sdpa": sdpa_attention_forward, -} - - -class Gemma2Attention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None): - super().__init__() - self.config = config - self.layer_idx = layer_idx - - self.attention_dropout = config.attention_dropout - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = config.head_dim - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.max_position_embeddings = config.max_position_embeddings - self.rope_theta = config.rope_theta +class Gemma2Attention(GemmaAttention): + def __init__(self, config: Gemma2Config, layer_idx: int): + super().__init__(config, layer_idx) + self.attn_logit_softcapping = self.config.attn_logit_softcapping + self.attention_dropout = self.config.attention_dropout self.is_causal = True self.scaling = config.query_pre_attn_scalar**-0.5 self.sliding_window = config.sliding_window if not bool(layer_idx % 2) else None - self.attn_logit_softcapping = config.attn_logit_softcapping - if self.hidden_size % self.num_heads != 0: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." 
- ) - - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) - self.rotary_emb = Gemma2RotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - base=self.rope_theta, - ) def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - cos, sin = self.rotary_emb(value_states, position_ids) + cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = { - "sin": sin, - "cos": cos, - "sliding_window": self.sliding_window, - "cache_position": cache_position, - } + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - if output_attentions and self.config._attn_implementation in ["sdpa", "flash_attention_2"]: - logger.warning_once("Setting `attention_type` to `flex_attention` because `output_attentions=True`") - attention_type = "flex_attention" - else: - attention_type = self.config._attn_implementation + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
+ ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - attn_output, attn_weights = GEMMA2_ATTENTION_FUNCTION[attention_type]( - self, query_states, key_states, value_states, attention_mask, output_attentions=output_attentions + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=self.attention_dropout if self.training else 0.0, + scaling=self.scaling, + sliding_window=self.sliding_window, + softcap=self.attn_logit_softcapping, + **kwargs, ) - attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class Gemma2FlashAttention2(Gemma2Attention): - def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None): - super().__init__(config, layer_idx) - self.config._attn_implementation = "flash_attention_2" - logger.warning_once( - "The `Gemma2FlashAttention2` class is deprecated in favor of simply modifying the `config._attn_implementation`" - "attribute of the `GemmaAttention` class! It will be removed in v4.48" - ) - - -class Gemma2SdpaAttention(Gemma2Attention): - def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None): - super().__init__(config, layer_idx) - self.config._attn_implementation = "sdpa" - logger.warning_once( - "The `Gemma2FlashAttention2` class is deprecated in favor of simply modifying the `config._attn_implementation`" - "attribute of the `GemmaAttention` class! It will be removed in v4.48" - ) + return attn_output, attn_weights class Gemma2DecoderLayer(nn.Module): @@ -494,6 +305,7 @@ def __init__(self, config: Gemma2Config, layer_idx: int): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, @@ -520,8 +332,9 @@ def forward( hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, + position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, @@ -543,37 +356,15 @@ def forward( if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) - return outputs -class Gemma2PreTrainedModel(GemmaPreTrainedModel): - _supports_quantized_cache = False - - @classmethod - def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False): - """ - Overloads `PreTrainedModel._check_and_enable_sdpa` so as to DISABLE torch SDPA by default on Gemma2 models. - SDPA reduces the model performance on Gemma2 because of the logits softcapping. 
- """ - config = super()._check_and_enable_sdpa(config, hard_check_only=hard_check_only) - - # if using the default path -> swap sdpa by eager - if not hard_check_only and config._attn_implementation == "sdpa": - config._attn_implementation = "eager" - - return config - - -class Gemma2Model(GemmaModel, Gemma2PreTrainedModel): +class Gemma2Model(GemmaModel): def __init__(self, config: Gemma2Config): super().__init__(config) self.layers = nn.ModuleList( [Gemma2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) - self.post_init() def forward( self, @@ -633,6 +424,9 @@ def forward( # embed positions hidden_states = inputs_embeds + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + # normalized # Gemma2 downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5 # See https://github.com/huggingface/transformers/pull/29402 @@ -651,6 +445,7 @@ def forward( layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, + position_embeddings, causal_mask, position_ids, past_key_values, @@ -661,6 +456,7 @@ def forward( else: layer_outputs = decoder_layer( hidden_states, + position_embeddings=position_embeddings, attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, @@ -679,16 +475,13 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = past_key_values if use_cache else None - - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( + output = BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, ) + return output if return_dict else output.to_tuple() @torch.no_grad() def _update_causal_mask( @@ -909,7 +702,7 @@ def __init__(self, config): "Gemma2Config", "Gemma2ForCausalLM", "Gemma2Model", - "Gemma2PreTrainedModel", + "Gemma2PreTrainedModel", # noqa: F822 "Gemma2ForSequenceClassification", "Gemma2ForTokenClassification", ] diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index b4a292d69de929..95ad0d9719951d 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -19,8 +19,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import math -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -29,20 +28,21 @@ from ...cache_utils import Cache, DynamicCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter -from ...modeling_flash_attention_utils import FlashAttentionKwargs, _flash_attention_forward +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast, TokenClassifierOutput, ) -from ...modeling_utils import PreTrainedModel +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import ( + LossKwargs, add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, ) @@ -55,55 +55,6 @@ _CONFIG_FOR_DOC = "GlmConfig" -class GlmRMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - GlmRMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - - def extra_repr(self): - return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" - - -class GlmRotaryEmbedding(nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): - super().__init__() - - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim)) - self.register_buffer("inv_freq", tensor=inv_freq, persistent=False) - - @torch.no_grad() - def forward(self, x, position_ids, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - self.inv_freq.to(x.device) - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) - position_ids_expanded = position_ids[:, None, :].float() - # Force float32 since bfloat16 loses precision on long contexts - # See https://github.com/huggingface/transformers/pull/29285 - device_type = x.device.type - device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() - sin = emb.sin() - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) - - class GlmMLP(nn.Module): def __init__(self, config): super().__init__() @@ -135,6 +86,32 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states 
= repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., 0::2] @@ -191,134 +168,38 @@ def __init__(self, config: GlmConfig, layer_idx: Optional[int] = None): super().__init__() self.config = config self.layer_idx = layer_idx - if layer_idx is None: - logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " - "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class." - ) - + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 self.attention_dropout = config.attention_dropout - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads self.is_causal = True - self.scaling = 1 / math.sqrt(self.head_dim) - - if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." 
- ) - - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False) - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling - - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - - attn_output = attn_output.view(bsz, q_len, -1) - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class GlmFlashAttention2(GlmAttention): - """ - Glm flash attention module. This module inherits from `GlmAttention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. 
- # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - output_attentions = False - - bsz, q_len, _ = hidden_states.size() + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - # Flash attention requires the input to have the shape - # batch_size x seq_length x head_dim x hidden_dim - # therefore we just need to keep the original shape - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) @@ -328,167 +209,123 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache - # to be able to avoid many of these transpose/reshape/view. - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - dropout_rate = self.attention_dropout if self.training else 0.0 - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. 
Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. (GlmRMSNorm handles it correctly) - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - attn_output = _flash_attention_forward( + attn_output, attn_weights = attention_interface( + self, query_states, key_states, value_states, attention_mask, - q_len, - position_ids=position_ids, - dropout=dropout_rate, - softmax_scale=self.scaling, - sliding_window=getattr(self, "sliding_window", None), - use_top_left_mask=self._flash_attn_uses_top_left_mask, - is_causal=self.is_causal, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, ) - attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) + return attn_output, attn_weights - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class GlmSdpaAttention(GlmAttention): - """ - Glm attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from - `GlmAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to - SDPA API. - """ - - # Adapted from GlmAttention.forward - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "GlmModel is using GlmSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. 
Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - ) - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) +class GlmRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + GlmRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - causal_mask = attention_mask - if attention_mask is not None: - causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] +class GlmRotaryEmbedding(nn.Module): + def __init__( + self, + config: GlmConfig, + device=None, + ): + super().__init__() + self.rope_kwargs = {} + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and causal_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. 
An inline conditional prevents dynamic shapes from compiling. - is_causal = True if causal_mask is None and q_len > 1 else False + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=causal_mask, - dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=is_causal, - scale=self.scaling, - ) + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(bsz, q_len, -1) + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len - attn_output = self.o_proj(attn_output) + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) - return attn_output, None, past_key_value + # Core RoPE block + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + # Advanced RoPE types (e.g. 
yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling -GLM_ATTENTION_CLASSES = { - "eager": GlmAttention, - "flash_attention_2": GlmFlashAttention2, - "sdpa": GlmSdpaAttention, -} + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) class GlmDecoderLayer(nn.Module): - def __init__(self, config: GlmConfig, layer_idx: Optional[int] = None): + def __init__(self, config: GlmConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = GLM_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + self.self_attn = GlmAttention(config=config, layer_idx=layer_idx) self.mlp = GlmMLP(config) self.input_layernorm = GlmRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -504,36 +341,14 @@ def forward( use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - **kwargs, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): - attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, - query_sequence_length, key_sequence_length)` if default attention is used. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence - position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): - Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, - with `head_dim` being the embedding dimension of each attention head. 
- kwargs (`dict`, *optional*): - Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code - into the model - """ residual = hidden_states hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, @@ -553,13 +368,9 @@ def forward( hidden_states = residual + hidden_states outputs = (hidden_states,) - if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) - return outputs @@ -705,14 +516,8 @@ def __init__(self, config: GlmConfig): [GlmDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) self.norm = GlmRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.rotary_emb = GlmRotaryEmbedding( - dim=int(config.head_dim * config.partial_rotary_factor), - max_position_embeddings=config.max_position_embeddings, - base=config.rope_theta, - ) + self.rotary_emb = GlmRotaryEmbedding(config=config) self.gradient_checkpointing = False - if getattr(config, "pretraining_tp", 1) != 1: - logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.") # Initialize weights and apply final processing self.post_init() @@ -729,7 +534,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, @@ -757,31 +562,22 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - # kept for BC (non `Cache` `past_key_values` inputs) - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True - if past_key_values is None: - past_key_values = DynamicCache() - else: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " - "will be removed in v4.47. 
Please convert your cache or use an appropriate `Cache` class " - "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" - ) + if use_cache and past_key_values is None: + past_key_values = DynamicCache() if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) + if position_ids is None: position_ids = cache_position.unsqueeze(0) causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions ) + hidden_states = inputs_embeds # create position embeddings to be shared across the decoder layers @@ -790,7 +586,6 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_decoder_cache = None for decoder_layer in self.layers[: self.config.num_hidden_layers]: if output_hidden_states: @@ -823,9 +618,6 @@ def forward( hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -835,18 +627,13 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None - if return_legacy_cache: - next_cache = next_cache.to_legacy_cache() - - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( + output = BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, ) + return output if return_dict else output.to_tuple() def _update_causal_mask( self, @@ -970,11 +757,14 @@ def _prepare_4d_causal_attention_mask_with_cache_position( return causal_mask +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... + + class GlmForCausalLM(GlmPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} - def __init__(self, config: GlmConfig): + def __init__(self, config): super().__init__(config) self.model = GlmModel(config) self.vocab_size = config.vocab_size @@ -1017,7 +807,7 @@ def forward( return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, num_logits_to_keep: int = 0, - **loss_kwargs, + **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -1038,16 +828,16 @@ def forward( ```python >>> from transformers import AutoTokenizer, GlmForCausalLM - >>> model = GlmForCausalLM.from_pretrained("google/glm-7b") - >>> tokenizer = AutoTokenizer.from_pretrained("google/glm-7b") + >>> model = GlmForCausalLM.from_pretrained("meta-glm/Glm-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-glm/Glm-2-7b-hf") - >>> prompt = "What is your favorite condiment?" + >>> prompt = "Hey, are you conscious? Can you talk to me?" >>> inputs = tokenizer(prompt, return_tensors="pt") >>> # Generate >>> generate_ids = model.generate(inputs.input_ids, max_length=30) >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "What is your favorite condiment?" + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
```""" output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -1067,6 +857,7 @@ def forward( output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, + **kwargs, ) hidden_states = outputs[0] @@ -1075,7 +866,7 @@ def forward( loss = None if labels is not None: - loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs) + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) if not return_dict: output = (logits,) + outputs[1:] @@ -1106,7 +897,7 @@ def forward( GLM_START_DOCSTRING, ) class GlmForSequenceClassification(GlmPreTrainedModel): - def __init__(self, config: GlmConfig): + def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels self.model = GlmModel(config) @@ -1202,7 +993,7 @@ def forward( GLM_START_DOCSTRING, ) class GlmForTokenClassification(GlmPreTrainedModel): - def __init__(self, config: GlmConfig): + def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels self.model = GlmModel(config) diff --git a/src/transformers/models/glm/modular_glm.py b/src/transformers/models/glm/modular_glm.py index 48605c15d30be3..ec07be10fb6a55 100644 --- a/src/transformers/models/glm/modular_glm.py +++ b/src/transformers/models/glm/modular_glm.py @@ -13,7 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import math from typing import Optional import torch @@ -21,26 +20,13 @@ import torch.utils.checkpoint from ...utils import logging -from ..gemma.modeling_gemma import ( - GemmaForCausalLM, - GemmaForSequenceClassification, - GemmaForTokenClassification, -) -from ..granite.modeling_granite import ( - GraniteAttention, - GraniteFlashAttention2, - GraniteSdpaAttention, -) from ..llama.modeling_llama import ( - LlamaDecoderLayer, - LlamaModel, - LlamaPreTrainedModel, -) -from ..phi3.modeling_phi3 import ( - Phi3MLP, - Phi3RMSNorm, - Phi3RotaryEmbedding, + LlamaAttention, + LlamaForCausalLM, + LlamaForSequenceClassification, + LlamaForTokenClassification, ) +from ..phi3.modeling_phi3 import Phi3MLP from .configuration_glm import GlmConfig @@ -49,14 +35,6 @@ _CHECKPOINT_FOR_DOC = "THUDM/glm-4-9b" -class GlmRMSNorm(Phi3RMSNorm): - pass - - -class GlmRotaryEmbedding(Phi3RotaryEmbedding): - pass - - class GlmMLP(Phi3MLP): pass @@ -110,83 +88,27 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): return q_embed, k_embed -class GlmAttention(GraniteAttention): +class GlmAttention(LlamaAttention): def __init__(self, config: GlmConfig, layer_idx: Optional[int] = None): super().__init__(config, layer_idx) - self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False) - self.scaling = 1 / math.sqrt(self.head_dim) + self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) -class GlmFlashAttention2(GlmAttention, GraniteFlashAttention2): +class GlmForCausalLM(LlamaForCausalLM): pass -class GlmSdpaAttention(GraniteSdpaAttention): +class GlmForSequenceClassification(LlamaForSequenceClassification): pass -GLM_ATTENTION_CLASSES = { - "eager": GlmAttention, - "flash_attention_2": GlmFlashAttention2, - "sdpa": GlmSdpaAttention, -} - - -class GlmDecoderLayer(LlamaDecoderLayer): - def __init__(self, config: GlmConfig, layer_idx: Optional[int] = None): - 
super().__init__() - - self.mlp = GlmMLP(config) - self.input_layernorm = GlmRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = GlmRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - -class GlmPreTrainedModel(LlamaPreTrainedModel): +class GlmForTokenClassification(LlamaForTokenClassification): pass -class GlmModel(GlmPreTrainedModel, LlamaModel): - def __init__(self, config: GlmConfig): - super().__init__(config) - self.layers = nn.ModuleList( - [GlmDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] - ) - self.norm = GlmRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.rotary_emb = GlmRotaryEmbedding( - dim=int(config.head_dim * config.partial_rotary_factor), - max_position_embeddings=config.max_position_embeddings, - base=config.rope_theta, - ) - self.gradient_checkpointing = False - - # Initialize weights and apply final processing - self.post_init() - - -class GlmForCausalLM(GemmaForCausalLM): - def __init__(self, config: GlmConfig): - super().__init__(config) - self.model = GlmModel(config) - self.post_init() - - -class GlmForSequenceClassification(GemmaForSequenceClassification): - def __init__(self, config: GlmConfig): - super().__init__(config) - self.model = GlmModel(config) - self.post_init() - - -class GlmForTokenClassification(GemmaForTokenClassification): - def __init__(self, config: GlmConfig): - super().__init__(config) - self.model = GlmModel(config) - self.post_init() - - __all__ = [ - "GlmPreTrainedModel", - "GlmModel", + "GlmPreTrainedModel", # noqa: F822 + "GlmModel", # noqa: F822 "GlmForCausalLM", "GlmForSequenceClassification", "GlmForTokenClassification", diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index 58143192c20482..ad53c7804ebeea 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -19,11 +19,10 @@ import os import warnings from dataclasses import dataclass -from typing import Optional, Tuple, Union +from typing import Callable, Optional, Tuple, Union import torch import torch.utils.checkpoint -from packaging import version from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss @@ -37,16 +36,13 @@ SequenceClassifierOutputWithPast, TokenClassifierOutput, ) -from ...modeling_utils import PreTrainedModel, SequenceSummary +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel, SequenceSummary from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer from ...utils import ( ModelOutput, add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, - get_torch_version, - is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, ) @@ -54,10 +50,6 @@ from .configuration_gpt2 import GPT2Config -if is_flash_attn_2_available(): - from ...modeling_flash_attention_utils import _flash_attention_forward - - logger = logging.get_logger(__name__) _CHECKPOINT_FOR_DOC = "openai-community/gpt2" @@ -120,6 +112,48 @@ def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path): return model +def eager_attention_forward(module, query, key, value, attention_mask, head_mask=None, **kwargs): + attn_weights = torch.matmul(query, key.transpose(-1, -2)) + + if module.scale_attn_weights: + attn_weights = attn_weights / torch.full( + [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device + ) + + # Layer-wise 
attention scaling + if module.scale_attn_by_inverse_layer_idx: + attn_weights = attn_weights / float(module.layer_idx + 1) + + if not module.is_cross_attention: + # if only "normal" attention layer implements causal mask + query_length, key_length = query.size(-2), key.size(-2) + causal_mask = module.bias[:, :, key_length - query_length : key_length, :key_length] + mask_value = torch.finfo(attn_weights.dtype).min + # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. + # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` + mask_value = torch.full([], mask_value, dtype=attn_weights.dtype, device=attn_weights.device) + attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value) + + if attention_mask is not None: + # Apply the attention mask + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise + attn_weights = attn_weights.type(value.dtype) + attn_weights = module.attn_dropout(attn_weights) + + # Mask heads if we want to + if head_mask is not None: + attn_weights = attn_weights * head_mask + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2) + + return attn_output, attn_weights + + class GPT2Attention(nn.Module): def __init__(self, config, is_cross_attention=False, layer_idx=None): super().__init__() @@ -180,46 +214,6 @@ def prune_heads(self, heads): self.num_heads = self.num_heads - len(heads) self.pruned_heads = self.pruned_heads.union(heads) - def _attn(self, query, key, value, attention_mask=None, head_mask=None): - attn_weights = torch.matmul(query, key.transpose(-1, -2)) - - if self.scale_attn_weights: - attn_weights = attn_weights / torch.full( - [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device - ) - - # Layer-wise attention scaling - if self.scale_attn_by_inverse_layer_idx: - attn_weights = attn_weights / float(self.layer_idx + 1) - - if not self.is_cross_attention: - # if only "normal" attention layer implements causal mask - query_length, key_length = query.size(-2), key.size(-2) - causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length] - mask_value = torch.finfo(attn_weights.dtype).min - # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. 
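The new module-level `eager_attention_forward` above masks future positions by overwriting their scores with the most negative value representable in the scores' dtype before the softmax, materialised as a 0-d tensor so dtype and device match. A toy demonstration of that masking step (illustrative only, not library code):

```python
# Future positions are filled with the dtype's most negative value so softmax sends them to ~0.
import torch

scores = torch.randn(1, 1, 4, 4)  # (batch, heads, query_len, key_len)
causal_mask = torch.tril(torch.ones(4, 4, dtype=torch.bool))[None, None]

# 0-d tensor with the scores' dtype/device, matching the comment above about scalar type errors.
mask_value = torch.full([], torch.finfo(scores.dtype).min, dtype=scores.dtype)
masked_scores = torch.where(causal_mask, scores, mask_value)

probs = torch.softmax(masked_scores, dim=-1)
print(probs[0, 0, 0])  # only the first key receives weight for the first query
assert torch.allclose(probs.triu(1), torch.zeros_like(probs))  # no attention to future positions
```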
- # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` - mask_value = torch.full([], mask_value, dtype=attn_weights.dtype, device=attn_weights.device) - attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value) - - if attention_mask is not None: - # Apply the attention mask - attn_weights = attn_weights + attention_mask - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise - attn_weights = attn_weights.type(value.dtype) - attn_weights = self.attn_dropout(attn_weights) - - # Mask heads if we want to - if head_mask is not None: - attn_weights = attn_weights * head_mask - - attn_output = torch.matmul(attn_weights, value) - - return attn_output, attn_weights - def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None): # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM) bsz, num_heads, q_seq_len, dk = query.size() @@ -269,25 +263,10 @@ def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, hea attn_weights = attn_weights * head_mask attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2) return attn_output, attn_weights - def _split_heads(self, tensor, num_heads, attn_head_size): - """ - Splits hidden_size dim into attn_head_size and num_heads - """ - new_shape = tensor.size()[:-1] + (num_heads, attn_head_size) - tensor = tensor.view(new_shape) - return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) - - def _merge_heads(self, tensor, num_heads, attn_head_size): - """ - Merges attn_head_size dim and num_attn_heads dim into hidden_size - """ - tensor = tensor.permute(0, 2, 1, 3).contiguous() - new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,) - return tensor.view(new_shape) - def forward( self, hidden_states: Optional[Tuple[torch.FloatTensor]], @@ -298,6 +277,7 @@ def forward( encoder_attention_mask: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = False, output_attentions: Optional[bool] = False, + **kwargs, ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: if encoder_hidden_states is not None: if not hasattr(self, "q_attn"): @@ -306,260 +286,73 @@ def forward( "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`." 
) - query = self.q_attn(hidden_states) - key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) + query_states = self.q_attn(hidden_states) + key_states, value_states = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) attention_mask = encoder_attention_mask else: - query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) + query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2) + + shape_q = (*query_states.shape[:-1], -1, self.head_dim) + shape_kv = (*key_states.shape[:-1], -1, self.head_dim) - query = self._split_heads(query, self.num_heads, self.head_dim) - key = self._split_heads(key, self.num_heads, self.head_dim) - value = self._split_heads(value, self.num_heads, self.head_dim) + query_states = query_states.reshape(shape_q).transpose(1, 2) + key_states = key_states.reshape(shape_kv).transpose(1, 2) + value_states = value_states.reshape(shape_kv).transpose(1, 2) if layer_past is not None: past_key, past_value = layer_past - key = torch.cat((past_key, key), dim=-2) - value = torch.cat((past_value, value), dim=-2) + key_states = torch.cat((past_key, key_states), dim=-2) + value_states = torch.cat((past_value, value_states), dim=-2) if use_cache is True: - present = (key, value) + present = (key_states, value_states) else: present = None - if self.reorder_and_upcast_attn: - attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask) - else: - attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) - - attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim) - attn_output = self.c_proj(attn_output) - attn_output = self.resid_dropout(attn_output) - - outputs = (attn_output, present) - if output_attentions: - outputs += (attn_weights,) - - return outputs # a, present, (attentions) - - -class GPT2FlashAttention2(GPT2Attention): - """ - GPT2 flash attention module. This module inherits from `GPT2Attention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
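Further down in this hunk, the rewritten `GPT2Attention.forward` drops the `_split_heads` / `_merge_heads` helpers and instead reshapes the fused projections with `reshape(*leading_dims, -1, head_dim).transpose(1, 2)`. A small sanity check, on assumed toy shapes, that the two formulations agree:

```python
# Illustrative check: the reshape/transpose used in the new forward matches the removed
# _split_heads helper (view + permute), and reshape(..., -1) undoes it like _merge_heads.
import torch

batch, seq, num_heads, head_dim = 2, 5, 4, 8
hidden = torch.randn(batch, seq, num_heads * head_dim)

# Old helper: view then permute to (batch, head, seq_length, head_features)
old_split = hidden.view(batch, seq, num_heads, head_dim).permute(0, 2, 1, 3)

# New style: reshape to (*leading_dims, -1, head_dim) then swap the seq and head axes
shape_q = (*hidden.shape[:-1], -1, head_dim)
new_split = hidden.reshape(shape_q).transpose(1, 2)

assert torch.equal(old_split, new_split)

# Merging heads back: transpose + flatten the last two axes, as done before c_proj
merged = new_split.transpose(1, 2).reshape(*hidden.shape[:-1], -1)
assert torch.equal(merged, hidden)
```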
- self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + is_cross_attention = encoder_hidden_states is not None + is_causal = attention_mask is None and query_states.shape[-2] > 1 and not is_cross_attention - def forward( - self, - hidden_states: Optional[Tuple[torch.FloatTensor]], - layer_past: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.FloatTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = False, - output_attentions: Optional[bool] = False, - ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: - bsz, _, _ = hidden_states.size() - if encoder_hidden_states is not None: - if not hasattr(self, "q_attn"): - raise ValueError( - "If class is used as cross attention, the weights `q_attn` have to be defined. " - "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`." + using_eager = self.config._attn_implementation == "eager" + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and (output_attentions or head_mask is not None): + using_eager = True + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' ) - - query = self.q_attn(hidden_states) - key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) - attention_mask = encoder_attention_mask - else: - query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) - - query = self._split_heads(query, self.num_heads, self.head_dim) - key = self._split_heads(key, self.num_heads, self.head_dim) - value = self._split_heads(value, self.num_heads, self.head_dim) - - if layer_past is not None: - past_key = layer_past[0] - past_value = layer_past[1] - key = torch.cat((past_key, key), dim=-2) - value = torch.cat((past_value, value), dim=-2) - - present = None - if use_cache is True: - present = (key, value) - - query_length = query.shape[2] - tgt_len = key.shape[2] - - # Flash attention requires the input to have the shape - # batch_size x seq_length x head_dim x hidden_dim - query = query.transpose(1, 2).view(bsz, query_length, self.num_heads, self.head_dim) - key = key.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim) - value = value.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim) - - attn_dropout = self.attn_dropout.p if self.training else 0.0 - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. 
(LlamaRMSNorm handles it correctly) - - if query.dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype else: - target_dtype = self.c_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query = query.to(target_dtype) - key = key.to(target_dtype) - value = value.to(target_dtype) - - attn_output = _flash_attention_forward( - query, - key, - value, - attention_mask, - query_length, - dropout=attn_dropout, - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - ) - - attn_weights_reshaped = attn_output.reshape(bsz, query_length, self.num_heads * self.head_dim) - attn_output = self.c_proj(attn_weights_reshaped) - attn_output = self.resid_dropout(attn_output) - - outputs = (attn_output, present) - if output_attentions: - outputs += (attn_weights_reshaped,) - - return outputs - - -class GPT2SdpaAttention(GPT2Attention): - """ - GPT2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from - `GPT2Attention` as the weights of the module stays untouched. The only changes are on the forward pass - to adapt to the SDPA API. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # Idea adapted from transformers.models.bert.modeling_bert.BertSdpaSelfAttention.__init__ - # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom - # attn_mask, so we need to call `.contiguous()`. This was fixed in torch==2.2.0. - # Reference: https://github.com/pytorch/pytorch/issues/112577 - self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0") - - def forward( - self, - hidden_states: Optional[Tuple[torch.FloatTensor]], - layer_past: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.FloatTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = False, - output_attentions: Optional[bool] = False, - ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: - if output_attentions or head_mask is not None: - logger.warning_once( - "`GPT2SdpaAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not support " - "`output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but " - "specifying the manual implementation will be required from Transformers version v5.0.0 onwards. " - 'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + # Attention functions are consistent with previous equivalent attention classes, however they do not support some options + # (e.g. layer scaling, head mask) that eager supports. These implementations are thus equivalent to previous code, but + # not necessarily to eager (if mentionned options are provided). 
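The lookup into `ALL_ATTENTION_FUNCTIONS` just below expects every backend to share the eager calling convention and to return an `(attn_output, attn_weights)` pair. As a rough illustration of what an SDPA-style backend of that shape could look like (the function name and exact signature here are assumptions, not the registered implementation):

```python
# Illustrative sketch of an SDPA-style backend with the calling convention the dispatch uses;
# this is not the actual function registered in ALL_ATTENTION_FUNCTIONS.
from typing import Optional, Tuple

import torch


def sdpa_attention_forward(
    module: torch.nn.Module,
    query: torch.Tensor,  # (batch, num_heads, q_len, head_dim)
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    dropout: float = 0.0,
    is_causal: Optional[bool] = None,
    **kwargs,
) -> Tuple[torch.Tensor, None]:
    if is_causal is None:
        # Mirror the heuristic in the forward above: use the fused causal kernel only when
        # no explicit mask was built and there is more than one query position.
        is_causal = attention_mask is None and query.shape[-2] > 1
    attn_output = torch.nn.functional.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, dropout_p=dropout, is_causal=is_causal
    )
    # (batch, num_heads, q_len, head_dim) -> (batch, q_len, num_heads, head_dim), ready to flatten
    attn_output = attn_output.transpose(1, 2).contiguous()
    # The fused kernel never materialises the attention matrix, hence the None; this is why the
    # forward falls back to eager when `output_attentions=True` or a head mask is requested.
    return attn_output, None
```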
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + if using_eager and self.reorder_and_upcast_attn: + attn_output, attn_weights = self._upcast_and_reordered_attn( + query_states, key_states, value_states, attention_mask, head_mask ) - return super().forward( - hidden_states=hidden_states, - layer_past=layer_past, - attention_mask=attention_mask, + else: + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, head_mask=head_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - use_cache=use_cache, - output_attentions=output_attentions, + dropout=self.attn_dropout.p if self.training else 0.0, + is_causal=is_causal, + **kwargs, ) - bsz, q_len, _ = hidden_states.size() - - # Initial attention projections - is_cross_attention = encoder_hidden_states is not None - if is_cross_attention: - if not hasattr(self, "q_attn"): - raise ValueError( - "If class is used as cross attention, the weights `q_attn` have to be defined. " - "Please make sure to instantiate class with `GPT2SdpaAttention(..., is_cross_attention=True)`." - ) - - query = self.q_attn(hidden_states) - key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) - attention_mask = encoder_attention_mask - else: - query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) - - query = self._split_heads(query, self.num_heads, self.head_dim) - key = self._split_heads(key, self.num_heads, self.head_dim) - value = self._split_heads(value, self.num_heads, self.head_dim) - - # Optional kv caching - if layer_past is not None: - past_key = layer_past[0] - past_value = layer_past[1] - key = torch.cat((past_key, key), dim=-2) - value = torch.cat((past_value, value), dim=-2) - - present = None - if use_cache is True: - present = (key, value) - - # Avoid torch==2.1.2 specific bug for the memory-efficient backend in SDPA - if self.require_contiguous_qkv and query.device.type == "cuda" and attention_mask is not None: - query = query.contiguous() - key = key.contiguous() - value = value.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. 
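The removed comment above explains why `is_causal` is computed outside the SDPA call: when no explicit mask was built and there is more than one query position, the fused causal kernel can be used directly. A quick, illustrative check that the flag and an explicit lower-triangular mask agree for plain self-attention:

```python
# Equivalence check (illustrative): for self-attention with no padding, `is_causal=True`
# matches passing an explicit lower-triangular boolean mask to SDPA.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
q = torch.randn(1, 2, 6, 8)  # (batch, heads, seq, head_dim)
k, v = torch.randn_like(q), torch.randn_like(q)

out_flag = F.scaled_dot_product_attention(q, k, v, is_causal=True)
explicit_mask = torch.tril(torch.ones(6, 6, dtype=torch.bool))
out_mask = F.scaled_dot_product_attention(q, k, v, attn_mask=explicit_mask)

assert torch.allclose(out_flag, out_mask, atol=1e-6)
```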
- is_causal = True if attention_mask is None and q_len > 1 and not is_cross_attention else False - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query, - key, - value, - attn_mask=attention_mask, - dropout_p=self.attn_dropout.p if self.training else 0.0, - is_causal=is_causal, - ) - - # Reshape outputs - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(bsz, q_len, self.embed_dim) - - # Final projection + attn_output = attn_output.reshape(*attn_output.shape[:-2], -1).contiguous() attn_output = self.c_proj(attn_output) attn_output = self.resid_dropout(attn_output) - return attn_output, present, None + outputs = (attn_output, present) + if output_attentions: + outputs += (attn_weights,) + + return outputs # a, present, (attentions) class GPT2MLP(nn.Module): @@ -579,22 +372,18 @@ def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.Fl return hidden_states -GPT2_ATTENTION_CLASSES = {"eager": GPT2Attention, "flash_attention_2": GPT2FlashAttention2, "sdpa": GPT2SdpaAttention} - - class GPT2Block(nn.Module): def __init__(self, config, layer_idx=None): super().__init__() hidden_size = config.hidden_size inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size - attention_class = GPT2_ATTENTION_CLASSES[config._attn_implementation] self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.attn = attention_class(config=config, layer_idx=layer_idx) + self.attn = GPT2Attention(config=config, layer_idx=layer_idx) self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) if config.add_cross_attention: - self.crossattention = attention_class(config=config, is_cross_attention=True, layer_idx=layer_idx) + self.crossattention = GPT2Attention(config=config, is_cross_attention=True, layer_idx=layer_idx) self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.mlp = GPT2MLP(inner_dim, config) diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 5326c7b907d4b1..403159cdf39c9a 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -278,7 +278,6 @@ class GPTBigCodeFlashAttention2(GPTBigCodeAttention): API of flash attention and deal with padding tokens in case the input contains any of them. """ - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py index 28bfbabc1fd8e0..6763695bfba036 100755 --- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -278,7 +278,6 @@ class GPTNeoFlashAttention2(GPTNeoSelfAttention): flash attention and deal with padding tokens in case the input contains any of them. 
""" - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index 70ff07ed7f6dcc..7152d72f5b7fc8 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -490,40 +490,18 @@ def __init__(self, config, layer_idx=None): class GPTNeoXRotaryEmbedding(nn.Module): def __init__( self, - dim=None, - max_position_embeddings=2048, - base=10000, + config: GPTNeoXConfig, device=None, - scaling_factor=1.0, - rope_type="default", - config: Optional[GPTNeoXConfig] = None, ): super().__init__() - # TODO (joao): remove the `if` below, only used for BC self.rope_kwargs = {} - if config is None: - logger.warning_once( - "`GPTNeoXRotaryEmbedding` can now be fully parameterized by passing the model config through the " - "`config` argument. All other arguments will be removed in v4.46" - ) - self.rope_kwargs = { - "rope_type": rope_type, - "factor": scaling_factor, - "dim": dim, - "base": base, - "max_position_embeddings": max_position_embeddings, - } - self.rope_type = rope_type - self.max_seq_len_cached = max_position_embeddings - self.original_max_seq_len = max_position_embeddings + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: - # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] diff --git a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py index c9e1b2d7213587..71602f01e7d6f8 100755 --- a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +++ b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py @@ -227,40 +227,18 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None): class GPTNeoXJapaneseRotaryEmbedding(nn.Module): def __init__( self, - dim=None, - max_position_embeddings=2048, - base=10000, + config: GPTNeoXJapaneseConfig, device=None, - scaling_factor=1.0, - rope_type="default", - config: Optional[GPTNeoXJapaneseConfig] = None, ): super().__init__() - # TODO (joao): remove the `if` below, only used for BC self.rope_kwargs = {} - if config is None: - logger.warning_once( - "`GPTNeoXJapaneseRotaryEmbedding` can now be fully parameterized by passing the model config through the " - "`config` argument. 
All other arguments will be removed in v4.46" - ) - self.rope_kwargs = { - "rope_type": rope_type, - "factor": scaling_factor, - "dim": dim, - "base": base, - "max_position_embeddings": max_position_embeddings, - } - self.rope_type = rope_type - self.max_seq_len_cached = max_position_embeddings - self.original_max_seq_len = max_position_embeddings + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: - # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index 1cc9cf369d1887..4af8f73b5f5eea 100644 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -266,7 +266,6 @@ class GPTJFlashAttention2(GPTJAttention): flash attention and deal with padding tokens in case the input contains any of them. """ - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py index 8cd24265d9edcf..2e045e149d95de 100644 --- a/src/transformers/models/granite/modeling_granite.py +++ b/src/transformers/models/granite/modeling_granite.py @@ -1,3 +1,9 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/granite/modular_granite.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_granite.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # coding=utf-8 # Copyright 2024 IBM and the HuggingFace Inc. team. All rights reserved. # @@ -13,29 +19,24 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
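The regenerated Granite file keeps only the module-level `rotate_half` / `apply_rotary_pos_emb` helpers (shown further down in this hunk), while the rotary embedding itself now mirrors the config-driven `GlmRotaryEmbedding` introduced above. A toy sketch of the shared RoPE math, assuming a plain base frequency and the half-split rotation convention:

```python
# Toy sketch (not library code) of the RoPE math shared by the Glm/Granite rotary embeddings:
# cos/sin are built from inverse frequencies and applied with the half-split rotate_half convention.
import torch

head_dim, seq_len = 8, 5
base = 10000.0
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))  # (head_dim / 2,)

position_ids = torch.arange(seq_len).float()        # (seq_len,)
freqs = torch.outer(position_ids, inv_freq)         # (seq_len, head_dim / 2)
emb = torch.cat((freqs, freqs), dim=-1)             # (seq_len, head_dim)
cos, sin = emb.cos(), emb.sin()


def rotate_half(x):
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


q = torch.randn(1, 1, seq_len, head_dim)            # (batch, heads, seq, head_dim)
q_embed = q * cos + rotate_half(q) * sin

# Each position is rotated pairwise, so vector norms are preserved.
assert torch.allclose(q_embed.norm(dim=-1), q.norm(dim=-1), atol=1e-5)
```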
-from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch -import torch.utils.checkpoint from torch import nn -from torch.nn import CrossEntropyLoss from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter -from ...modeling_flash_attention_utils import _flash_attention_forward -from ...modeling_outputs import ( - BaseModelOutputWithPast, - CausalLMOutputWithPast, -) +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS -from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import ALL_LAYERNORM_LAYERS +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack from ...utils import ( + LossKwargs, add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, ) @@ -43,96 +44,9 @@ logger = logging.get_logger(__name__) - _CONFIG_FOR_DOC = "GraniteConfig" -# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Granite -class GraniteRMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - GraniteRMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - - def extra_repr(self): - return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" - - -ALL_LAYERNORM_LAYERS.append(GraniteRMSNorm) - - -class GraniteRotaryEmbedding(nn.Module): - def __init__(self, config: GraniteConfig): - super().__init__() - # TODO (joao): remove the `if` below, only used for BC - self.rope_kwargs = {} - if config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings - - self.config = config - self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device=None, **self.rope_kwargs) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self.original_inv_freq = self.inv_freq - - def _dynamic_frequency_update(self, position_ids, device): - """ - dynamic RoPE layers should recompute `inv_freq` in the following situations: - 1 - growing beyond the cached sequence length (allow scaling) - 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) - """ - seq_len = torch.max(position_ids) + 1 - if seq_len > self.max_seq_len_cached: # growth - inv_freq, self.attention_scaling = self.rope_init_fn( - self.config, device, seq_len=seq_len, **self.rope_kwargs - ) - self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation - self.max_seq_len_cached = seq_len - - if seq_len < self.original_max_seq_len and self.max_seq_len_cached > 
self.original_max_seq_len: # reset - self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) - self.max_seq_len_cached = self.original_max_seq_len - - @torch.no_grad() - def forward(self, x, position_ids): - if "dynamic" in self.rope_type: - self._dynamic_frequency_update(position_ids, device=x.device) - - # Core RoPE block - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) - position_ids_expanded = position_ids[:, None, :].float() - # Force float32 (see https://github.com/huggingface/transformers/pull/29285) - device_type = x.device.type - device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() - sin = emb.sin() - - # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention - cos = cos * self.attention_scaling - sin = sin * self.attention_scaling - - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) - - -# Copied from transformers.models.llama.modeling_llama.rotate_half with Llama->Granite def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] @@ -140,7 +54,6 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb with Llama->Granite def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -168,24 +81,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): return q_embed, k_embed -class GraniteMLP(nn.Module): - # Copied from transformers.models.llama.modeling_llama.LlamaMLP.__init__ with Llama->Granite - def __init__(self, config): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) - self.act_fn = ACT2FN[config.hidden_act] - - # Copied from transformers.models.gemma.modeling_gemma.GemmaMLP.forward with Gemma->Granite - def forward(self, x): - return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - - -# Copied from transformers.models.llama.modeling_llama.repeat_kv with Llama->Granite def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, @@ -198,6 +93,32 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + class GraniteAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -205,135 +126,40 @@ def __init__(self, config: GraniteConfig, layer_idx: Optional[int] = None): super().__init__() self.config = config self.layer_idx = layer_idx - if layer_idx is None: - logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " - "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class." - ) - + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = config.attention_multiplier self.attention_dropout = config.attention_dropout - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads self.is_causal = True - self.scaling = config.attention_multiplier - - if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." 
- ) - - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling - - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - - attn_output = attn_output.view(bsz, q_len, -1) - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class GraniteFlashAttention2(GraniteAttention): - """ - Granite flash attention module. This module inherits from `GraniteAttention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. 
- """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - output_attentions = False - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) - # Flash attention requires the input to have the shape - # batch_size x seq_length x head_dim x hidden_dim - # therefore we just need to keep the original shape - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) @@ -343,172 +169,77 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache - # to be able to avoid many of these transpose/reshape/view. 
- query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - dropout_rate = self.attention_dropout if self.training else 0.0 - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. (GraniteRMSNorm handles it correctly) - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) else: - target_dtype = self.q_proj.weight.dtype + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - attn_output = _flash_attention_forward( + attn_output, attn_weights = attention_interface( + self, query_states, key_states, value_states, attention_mask, - q_len, - position_ids=position_ids, - dropout=dropout_rate, - softmax_scale=self.scaling, - sliding_window=getattr(self, "sliding_window", None), - use_top_left_mask=self._flash_attn_uses_top_left_mask, - is_causal=self.is_causal, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, ) - attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights -class GraniteSdpaAttention(GraniteAttention): - """ - Granite attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from - `GraniteAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to - SDPA API. 
- """ - - # Adapted from GraniteAttention.forward - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "GraniteModel is using GraniteSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - ) - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - causal_mask = attention_mask - if attention_mask is not None: - causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and causal_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. 
- is_causal = True if causal_mask is None and q_len > 1 else False - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=causal_mask, - dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=is_causal, - scale=self.scaling, - ) +class GraniteRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + GraniteRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(bsz, q_len, -1) + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) - attn_output = self.o_proj(attn_output) + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" - return attn_output, None, past_key_value +class GraniteMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + self.act_fn = ACT2FN[config.hidden_act] -GRANITE_ATTENTION_CLASSES = { - "eager": GraniteAttention, - "flash_attention_2": GraniteFlashAttention2, - "sdpa": GraniteSdpaAttention, -} + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj class GraniteDecoderLayer(nn.Module): def __init__(self, config: GraniteConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size - - self.self_attn = GRANITE_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + self.self_attn = GraniteAttention(config=config, layer_idx=layer_idx) self.mlp = GraniteMLP(config) self.input_layernorm = GraniteRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = GraniteRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.residual_multiplier = config.residual_multiplier def forward( @@ -550,7 +281,7 @@ def forward( hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, @@ -567,19 +298,81 @@ def forward( residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states * self.residual_multiplier + hidden_states = residual + hidden_states * self.residual_multiplier # main diff with Llama outputs = (hidden_states,) if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) - return outputs +class GraniteRotaryEmbedding(nn.Module): + def __init__( + self, + config: GraniteConfig, + device=None, + ): + super().__init__() + self.rope_kwargs = {} + # BC: "rope_type" was originally "type" + if hasattr(config, 
"rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + GRANITE_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. 
Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads @@ -601,7 +394,6 @@ def forward( "The bare Granite Model outputting raw hidden-states without any specific head on top.", GRANITE_START_DOCSTRING, ) -# Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel with Llama->Granite class GranitePreTrainedModel(PreTrainedModel): config_class = GraniteConfig base_model_prefix = "model" @@ -723,17 +515,9 @@ def __init__(self, config: GraniteConfig): [GraniteDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) self.norm = GraniteRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = GraniteRotaryEmbedding(config=config) self.gradient_checkpointing = False - self.embedding_multiplier = config.embedding_multiplier - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.max_position_embeddings = config.max_position_embeddings - self.rope_theta = config.rope_theta - - # rope - self.rotary_emb = GraniteRotaryEmbedding(config) # Initialize weights and apply final processing self.post_init() @@ -750,13 +534,14 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -777,27 +562,17 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - inputs_embeds = inputs_embeds * self.embedding_multiplier + inputs_embeds = inputs_embeds * self.embedding_multiplier # main diff with Llama - # kept for BC (non `Cache` `past_key_values` inputs) - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True - if past_key_values is None: - past_key_values = DynamicCache() - else: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " - "will be removed in v4.47. 
Please convert your cache or use an appropriate `Cache` class " - "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" - ) + if use_cache and past_key_values is None: + past_key_values = DynamicCache() if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) + if position_ids is None: position_ids = cache_position.unsqueeze(0) @@ -805,7 +580,6 @@ def forward( attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions ) - # embed positions hidden_states = inputs_embeds # create position embeddings to be shared across the decoder layers @@ -814,9 +588,8 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_decoder_cache = None - for decoder_layer in self.layers: + for decoder_layer in self.layers[: self.config.num_hidden_layers]: if output_hidden_states: all_hidden_states += (hidden_states,) @@ -842,13 +615,11 @@ def forward( use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, + **flash_attn_kwargs, ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -858,18 +629,13 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None - if return_legacy_cache: - next_cache = next_cache.to_legacy_cache() - - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( + output = BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, ) + return output if return_dict else output.to_tuple() def _update_causal_mask( self, @@ -879,11 +645,6 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static - # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. - # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using - # `fullgraph=True`. 
See more context in https://github.com/huggingface/transformers/pull/29114 - if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and 0.0 in attention_mask: return attention_mask @@ -906,7 +667,6 @@ def _update_causal_mask( return None dtype, device = input_tensor.dtype, input_tensor.device - min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] if using_static_cache: target_length = past_key_values.get_max_cache_shape() @@ -917,24 +677,17 @@ def _update_causal_mask( else past_seen_tokens + sequence_length + 1 ) - if attention_mask is not None and attention_mask.dim() == 4: - causal_mask = attention_mask - else: - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device - ) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + if ( self.config._attn_implementation == "sdpa" and attention_mask is not None @@ -944,12 +697,12 @@ def _update_causal_mask( # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) return causal_mask @staticmethod - # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, @@ -1006,10 +759,13 @@ def _prepare_4d_causal_attention_mask_with_cache_position( return causal_mask +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... + + class GraniteForCausalLM(GranitePreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} - # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with Llama->Granite def __init__(self, config): super().__init__(config) self.model = GraniteModel(config) @@ -1052,6 +808,8 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, + **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -1060,6 +818,11 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). 
Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + Returns: Example: @@ -1067,8 +830,8 @@ def forward( ```python >>> from transformers import AutoTokenizer, GraniteForCausalLM - >>> model = GraniteForCausalLM.from_pretrained("ibm/PowerLM-3b") - >>> tokenizer = AutoTokenizer.from_pretrained("ibm/PowerLM-3b") + >>> model = GraniteForCausalLM.from_pretrained("meta-granite/Granite-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-granite/Granite-2-7b-hf") >>> prompt = "Hey, are you conscious? Can you talk to me?" >>> inputs = tokenizer(prompt, return_tensors="pt") @@ -1096,26 +859,17 @@ def forward( output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, + **kwargs, ) hidden_states = outputs[0] - logits = self.lm_head(hidden_states) - logits = logits / self.config.logits_scaling + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + logits = logits / self.config.logits_scaling # main diff with Llama loss = None if labels is not None: - # Upcast to float if we need to compute the loss to avoid potential precision issues - logits = logits.float() - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) if not return_dict: output = (logits,) + outputs[1:] @@ -1128,12 +882,3 @@ def forward( hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) - - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past diff --git a/src/transformers/models/granite/modular_granite.py b/src/transformers/models/granite/modular_granite.py new file mode 100644 index 00000000000000..698280085f1852 --- /dev/null +++ b/src/transformers/models/granite/modular_granite.py @@ -0,0 +1,291 @@ +# coding=utf-8 +# Copyright 2024 IBM and the HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn + +from ...cache_utils import Cache, DynamicCache +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...processing_utils import Unpack +from ...utils import LossKwargs, logging +from ..llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaForCausalLM, LlamaModel +from .configuration_granite import GraniteConfig + + +logger = logging.get_logger(__name__) + + +class GraniteAttention(LlamaAttention): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: GraniteConfig, layer_idx: Optional[int] = None): + super().__init__(config, layer_idx) + self.scaling = config.attention_multiplier + + +class GraniteDecoderLayer(LlamaDecoderLayer): + def __init__(self, config: GraniteConfig, layer_idx: int): + super().__init__(config, layer_idx) + self.residual_multiplier = config.residual_multiplier + self.self_attn = GraniteAttention(config=config, layer_idx=layer_idx) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. 
+ kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states * self.residual_multiplier + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states * self.residual_multiplier # main diff with Llama + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + return outputs + + +class GraniteModel(LlamaModel): + def __init__(self, config: GraniteConfig): + super().__init__(config) + self.embedding_multiplier = config.embedding_multiplier + self.layers = nn.ModuleList( + [GraniteDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." 
+ ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + inputs_embeds = inputs_embeds * self.embedding_multiplier # main diff with Llama + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + output = BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + return output if return_dict else output.to_tuple() + + +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... 
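+# The (empty) class above only merges the two TypedDicts of optional keyword arguments, i.e. the
+# flash-attention kwargs and the loss kwargs, so that both can be forwarded unchanged through
+# `**kwargs: Unpack[KwargsForCausalLM]` in the forward pass of GraniteForCausalLM below.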
+ + +class GraniteForCausalLM(LlamaForCausalLM): + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, + **kwargs: Unpack[KwargsForCausalLM], + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + logits = logits / self.config.logits_scaling # main diff with Llama + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/src/transformers/models/granitemoe/modeling_granitemoe.py b/src/transformers/models/granitemoe/modeling_granitemoe.py index 9f5fdeea07d4b1..1c4c06bbc8d71e 100644 --- a/src/transformers/models/granitemoe/modeling_granitemoe.py +++ b/src/transformers/models/granitemoe/modeling_granitemoe.py @@ -158,11 +158,15 @@ def extra_repr(self): # Copied from transformers.models.granite.modeling_granite.GraniteRotaryEmbedding with Granite->GraniteMoe class GraniteMoeRotaryEmbedding(nn.Module): - def __init__(self, config: GraniteMoeConfig): + def __init__( + self, + config: GraniteMoeConfig, + device=None, + ): super().__init__() - # TODO (joao): remove the `if` below, only used for BC self.rope_kwargs = {} - if config.rope_scaling is not None: + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: self.rope_type = "default" @@ -172,7 +176,7 @@ def __init__(self, config: GraniteMoeConfig): self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device=None, **self.rope_kwargs) + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) 
self.register_buffer("inv_freq", inv_freq, persistent=False) self.original_inv_freq = self.inv_freq @@ -413,7 +417,8 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -# Copied from transformers.models.granite.modeling_granite.GraniteAttention with Granite->GraniteMoe +# copied from transformers.models.granite.modeling_granite.GraniteAttention with Granite->GraniteMoe +# no longer copied after attention refactors class GraniteMoeAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -510,7 +515,8 @@ def forward( return attn_output, attn_weights, past_key_value -# Copied from transformers.models.granite.modeling_granite.GraniteFlashAttention2 with Granite->GraniteMoe +# NO LONGER EXIST Copied from transformers.models.granite.modeling_granite.GraniteFlashAttention2 with Granite->GraniteMoe +# TODO cyril: modular class GraniteMoeFlashAttention2(GraniteMoeAttention): """ GraniteMoe flash attention module. This module inherits from `GraniteMoeAttention` as the weights of the module stays @@ -617,7 +623,8 @@ def forward( return attn_output, attn_weights, past_key_value -# Copied from transformers.models.granite.modeling_granite.GraniteSdpaAttention with Granite->GraniteMoe +# NO LONGER EXIST Copied from transformers.models.granite.modeling_granite.GraniteSdpaAttention with Granite->GraniteMoe +# TODO cyril: modular class GraniteMoeSdpaAttention(GraniteMoeAttention): """ GraniteMoe attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py index 03904a6abfa08b..1629f7d4f3feae 100755 --- a/src/transformers/models/hubert/modeling_hubert.py +++ b/src/transformers/models/hubert/modeling_hubert.py @@ -563,7 +563,6 @@ class HubertFlashAttention2(HubertAttention): flash attention and deal with padding tokens in case the input contains any of them. """ - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py index 8bd24728b03885..b2ffbcbc695696 100644 --- a/src/transformers/models/idefics/modeling_idefics.py +++ b/src/transformers/models/idefics/modeling_idefics.py @@ -444,7 +444,6 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. diff --git a/src/transformers/models/idefics2/modeling_idefics2.py b/src/transformers/models/idefics2/modeling_idefics2.py index 3d46c3bd82e788..6d7295b5120d29 100644 --- a/src/transformers/models/idefics2/modeling_idefics2.py +++ b/src/transformers/models/idefics2/modeling_idefics2.py @@ -272,7 +272,6 @@ class Idefics2VisionFlashAttention2(Idefics2VisionAttention): flash attention and deal with padding tokens in case the input contains any of them. 
""" - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -859,7 +858,8 @@ def forward( return attn_output, attn_weights, past_key_value -# Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with MistralAttention->Idefics2PerceiverAttention,MistralFlashAttention->Idefics2PerceiverFlashAttention,Mistral->Idefics2 +# NO LONGER EXIST Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with MistralAttention->Idefics2PerceiverAttention,MistralFlashAttention->Idefics2PerceiverFlashAttention,Mistral->Idefics2 +# TODO cyril: modular class Idefics2PerceiverFlashAttention2(Idefics2PerceiverAttention): """ Idefics2 flash attention module. This module inherits from `Idefics2PerceiverAttention` as the weights of the module stays @@ -867,7 +867,6 @@ class Idefics2PerceiverFlashAttention2(Idefics2PerceiverAttention): flash attention and deal with padding tokens in case the input contains any of them. """ - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/src/transformers/models/idefics3/modeling_idefics3.py b/src/transformers/models/idefics3/modeling_idefics3.py index 31d43948fbd565..3a52b8b6d54d0e 100644 --- a/src/transformers/models/idefics3/modeling_idefics3.py +++ b/src/transformers/models/idefics3/modeling_idefics3.py @@ -273,7 +273,6 @@ class Idefics3VisionFlashAttention2(Idefics3VisionAttention): flash attention and deal with padding tokens in case the input contains any of them. """ - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py index a185d5ebc6e86c..ae7470d789b27e 100755 --- a/src/transformers/models/jamba/modeling_jamba.py +++ b/src/transformers/models/jamba/modeling_jamba.py @@ -384,7 +384,6 @@ class JambaFlashAttention2(JambaAttention): flash attention and deal with padding tokens in case the input contains any of them. 
""" - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -835,6 +834,7 @@ def forward( class JambaMLP(nn.Module): def __init__(self, config): super().__init__() + self.config = config self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) @@ -842,8 +842,9 @@ def __init__(self, config): self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) self.act_fn = ACT2FN[config.hidden_act] - def forward(self, hidden_state): - return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj # Adapted from transformers.models.mixtral.modeling_mixtral.MixtralSparseMoeBlock with Mistral->Jamba diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py index a4bb1d78fdc5ce..7b7fd5a90d69ed 100644 --- a/src/transformers/models/jetmoe/modeling_jetmoe.py +++ b/src/transformers/models/jetmoe/modeling_jetmoe.py @@ -32,6 +32,7 @@ MoeModelOutputWithPast, SequenceClassifierOutputWithPast, ) +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import PreTrainedModel from ...utils import ( add_start_docstrings, @@ -385,24 +386,55 @@ def extra_repr(self): # Copied from transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding with Gemma->JetMoe class JetMoeRotaryEmbedding(nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + def __init__( + self, + config: JetMoeConfig, + device=None, + ): super().__init__() + self.rope_kwargs = {} + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim)) - self.register_buffer("inv_freq", tensor=inv_freq, persistent=False) + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + 
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len @torch.no_grad() - def forward(self, x, position_ids, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - self.inv_freq.to(x.device) + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() - # Force float32 since bfloat16 loses precision on long contexts - # See https://github.com/huggingface/transformers/pull/29285 + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): @@ -410,6 +442,11 @@ def forward(self, x, position_ids, seq_len=None): emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) @@ -486,11 +523,7 @@ def __init__(self, config: JetMoeConfig, layer_idx: Optional[int] = None): self.kv_proj = torch.nn.Linear(config.hidden_size, self.kv_projection_size * 2, bias=False) - self.rotary_emb = JetMoeRotaryEmbedding( - config.kv_channels, - max_position_embeddings=config.max_position_embeddings, - base=config.rope_theta, - ) + self.rotary_emb = JetMoeRotaryEmbedding(config) def forward( self, @@ -641,7 +674,6 @@ def forward( class JetMoeFlashAttention2(JetMoeAttention): - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 8e06098b04c63a..5be33c26414cd7 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -17,8 +17,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import math -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch import torch.utils.checkpoint @@ -28,7 +27,7 @@ from ...cache_utils import Cache, DynamicCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter -from ...modeling_flash_attention_utils import FlashAttentionKwargs, _flash_attention_forward +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -37,7 +36,7 @@ TokenClassifierOutput, ) from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...pytorch_utils import ALL_LAYERNORM_LAYERS from ...utils import ( @@ -45,7 +44,6 @@ add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, ) @@ -84,40 +82,18 @@ def extra_repr(self): class LlamaRotaryEmbedding(nn.Module): def __init__( self, - dim=None, - max_position_embeddings=2048, - base=10000, + config: LlamaConfig, device=None, - scaling_factor=1.0, - rope_type="default", - config: Optional[LlamaConfig] = None, ): super().__init__() - # TODO (joao): remove the `if` below, only used for BC self.rope_kwargs = {} - if config is None: - logger.warning_once( - "`LlamaRotaryEmbedding` can now be fully parameterized by passing the model config through the " - "`config` argument. All other arguments will be removed in v4.46" - ) - self.rope_kwargs = { - "rope_type": rope_type, - "factor": scaling_factor, - "dim": dim, - "base": base, - "max_position_embeddings": max_position_embeddings, - } - self.rope_type = rope_type - self.max_seq_len_cached = max_position_embeddings - self.original_max_seq_len = max_position_embeddings + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: - # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] @@ -230,144 +206,73 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, 
dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + class LlamaAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): + def __init__(self, config: LlamaConfig, layer_idx: int): super().__init__() self.config = config self.layer_idx = layer_idx - if layer_idx is None: - logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " - "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class." - ) - + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 self.attention_dropout = config.attention_dropout - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads) - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.max_position_embeddings = config.max_position_embeddings - self.rope_theta = config.rope_theta self.is_causal = True - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - # use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used - query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = 
repeat_kv(value_states, self.num_key_value_groups) - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - - attn_output = attn_output.reshape(bsz, q_len, -1) - - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class LlamaFlashAttention2(LlamaAttention): - """ - Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
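The Llama hunk above pulls the eager path out of the attention class into a standalone `eager_attention_forward` helper: KV heads are repeated for grouped-query attention, then a scaled matmul, float32 softmax, dropout, and a value matmul. A minimal sketch of that computation over plain tensors (the head counts and `n_rep` here are illustrative, not taken from any config):

```python
import torch
import torch.nn.functional as F

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # (batch, num_kv_heads, seq, head_dim) -> (batch, num_kv_heads * n_rep, seq, head_dim)
    batch, num_kv_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_kv_heads * n_rep, slen, head_dim)

def eager_attention(query, key, value, attention_mask=None, scaling=None, dropout=0.0, n_rep=1, training=False):
    # query: (batch, num_heads, q_len, head_dim); key/value: (batch, num_kv_heads, kv_len, head_dim)
    if scaling is None:
        scaling = query.shape[-1] ** -0.5
    key_states = repeat_kv(key, n_rep)
    value_states = repeat_kv(value, n_rep)
    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask[:, :, :, : key_states.shape[-2]]
    # Upcast the softmax to float32, then drop back to the query dtype.
    attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = F.dropout(attn_weights, p=dropout, training=training)
    attn_output = torch.matmul(attn_weights, value_states).transpose(1, 2).contiguous()
    return attn_output, attn_weights

q = torch.randn(1, 8, 5, 64)      # 8 query heads
k = v = torch.randn(1, 2, 5, 64)  # 2 KV heads -> n_rep = 4
out, weights = eager_attention(q, k, v, n_rep=4)
print(out.shape)  # torch.Size([1, 5, 8, 64])
```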
- self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if isinstance(past_key_value, StaticCache): - raise ValueError( - "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " - "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" - ) + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) - output_attentions = False - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - # Flash attention requires the input to have the shape - # batch_size x seq_length x head_dim x hidden_dim - # therefore we just need to keep the original shape - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) @@ -377,159 +282,30 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache - # to be able to avoid many of these transpose/reshape/view. - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - dropout_rate = self.attention_dropout if self.training else 0.0 - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. 
- # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. (LlamaRMSNorm handles it correctly) - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - attn_output = _flash_attention_forward( + attn_output, attn_weights = attention_interface( + self, query_states, key_states, value_states, attention_mask, - q_len, - position_ids=position_ids, - dropout=dropout_rate, - sliding_window=getattr(self, "sliding_window", None), - use_top_left_mask=self._flash_attn_uses_top_left_mask, - is_causal=self.is_causal, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, **kwargs, ) - attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class LlamaSdpaAttention(LlamaAttention): - """ - Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from - `LlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to - SDPA API. - """ - - # Adapted from LlamaAttention.forward - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
- ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - ) - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - # use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used - query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - causal_mask = attention_mask - if attention_mask is not None: - causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and causal_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. 
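With the per-backend subclasses removed, the refactored forward picks an attention backend at call time: it defaults to the eager helper and otherwise looks up `config._attn_implementation` in `ALL_ATTENTION_FUNCTIONS`, falling back to eager when SDPA is asked for `output_attentions=True`. Below is a simplified sketch of that dispatch pattern; the `ATTENTION_FUNCTIONS` dict and `sdpa_attention_forward` wrapper are stand-ins written for this example, not the library objects.

```python
from typing import Callable, Dict
import torch
import torch.nn.functional as F

def sdpa_attention_forward(module, query, key, value, attention_mask=None, dropout=0.0, **kwargs):
    # Thin wrapper over torch's fused SDPA kernel; it applies 1/sqrt(head_dim) scaling internally
    # and cannot return attention weights, hence the `None` second output.
    # `module` is unused here; it only mirrors the interface in the diff, which passes `self` first.
    attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=dropout)
    return attn_output.transpose(1, 2).contiguous(), None

# Stand-in registry keyed like `config._attn_implementation` ("sdpa", "flash_attention_2", ...).
ATTENTION_FUNCTIONS: Dict[str, Callable] = {"sdpa": sdpa_attention_forward}

def pick_attention(attn_implementation: str, output_attentions: bool, eager_fn: Callable) -> Callable:
    # Default to the eager helper; also fall back to it when SDPA is asked for attention weights.
    if attn_implementation == "eager" or (attn_implementation == "sdpa" and output_attentions):
        return eager_fn
    return ATTENTION_FUNCTIONS[attn_implementation]
```

The design point of the dispatch is that the module keeps a single forward path (projections, RoPE, cache update) and only the final score/softmax/matmul step is swapped per backend.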
- is_causal = True if causal_mask is None and q_len > 1 else False - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=causal_mask, - dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=is_causal, - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(bsz, q_len, -1) - + attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) - - return attn_output, None, past_key_value - - -LLAMA_ATTENTION_CLASSES = { - "eager": LlamaAttention, - "flash_attention_2": LlamaFlashAttention2, - "sdpa": LlamaSdpaAttention, -} + return attn_output, attn_weights class LlamaDecoderLayer(nn.Module): @@ -537,7 +313,7 @@ def __init__(self, config: LlamaConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + self.self_attn = LlamaAttention(config=config, layer_idx=layer_idx) self.mlp = LlamaMLP(config) self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -553,36 +329,14 @@ def forward( use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - **kwargs, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): - attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, - query_sequence_length, key_sequence_length)` if default attention is used. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence - position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): - Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, - with `head_dim` being the embedding dimension of each attention head. 
- kwargs (`dict`, *optional*): - Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code - into the model - """ residual = hidden_states hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, @@ -602,13 +356,9 @@ def forward( hidden_states = residual + hidden_states outputs = (hidden_states,) - if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) - return outputs @@ -755,10 +505,7 @@ def __init__(self, config: LlamaConfig): ) self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = LlamaRotaryEmbedding(config=config) - self.gradient_checkpointing = False - if getattr(config, "pretraining_tp", 1) != 1: - logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.") # Initialize weights and apply final processing self.post_init() @@ -775,7 +522,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, @@ -803,31 +550,22 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - # kept for BC (non `Cache` `past_key_values` inputs) - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True - if past_key_values is None: - past_key_values = DynamicCache() - else: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " - "will be removed in v4.47. 
Please convert your cache or use an appropriate `Cache` class " - "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" - ) + if use_cache and past_key_values is None: + past_key_values = DynamicCache() if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) + if position_ids is None: position_ids = cache_position.unsqueeze(0) causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions ) + hidden_states = inputs_embeds # create position embeddings to be shared across the decoder layers @@ -836,7 +574,6 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_decoder_cache = None for decoder_layer in self.layers[: self.config.num_hidden_layers]: if output_hidden_states: @@ -869,9 +606,6 @@ def forward( hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -881,18 +615,13 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None - if return_legacy_cache: - next_cache = next_cache.to_legacy_cache() - - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( + output = BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, ) + return output if return_dict else output.to_tuple() def _update_causal_mask( self, diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index cc35a3504255bf..4e116e7e3db585 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -348,7 +348,6 @@ class M2M100FlashAttention2(M2M100Attention): flash attention and deal with padding tokens in case the input contains any of them. """ - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index 95cd7c65ef32c2..e272c98f06975a 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -291,7 +291,6 @@ class MBartFlashAttention2(MBartAttention): flash attention and deal with padding tokens in case the input contains any of them. 
""" - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/src/transformers/models/mimi/modeling_mimi.py b/src/transformers/models/mimi/modeling_mimi.py index cbdd2c663c5844..1440ce1e075c95 100644 --- a/src/transformers/models/mimi/modeling_mimi.py +++ b/src/transformers/models/mimi/modeling_mimi.py @@ -26,6 +26,7 @@ from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import BaseModelOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import PreTrainedModel from ...utils import ( ModelOutput, @@ -364,24 +365,55 @@ def forward(self, x: torch.Tensor): # Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Mimi class MimiRotaryEmbedding(nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + def __init__( + self, + config: MimiConfig, + device=None, + ): super().__init__() + self.rope_kwargs = {} + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len @torch.no_grad() - # copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.forward - # TODO(joao): add me back asap :) def forward(self, x, position_ids): - # x: [bs, num_attention_heads, seq_len, head_size] + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() - # Force float32 since bfloat16 loses precision on long contexts - # See https://github.com/huggingface/transformers/pull/29285 + # 
Force float32 (see https://github.com/huggingface/transformers/pull/29285) device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): @@ -389,6 +421,11 @@ def forward(self, x, position_ids): emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) @@ -457,7 +494,8 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -# Copied from transformers.models.gemma.modeling_gemma.GemmaAttention with Gemma->Mimi +# copied from transformers.models.gemma.modeling_gemma.GemmaAttention with Gemma->Mimi +# no longer copied after attention refactors class MimiAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -493,11 +531,7 @@ def __init__(self, config: MimiConfig, layer_idx: Optional[int] = None): self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) - self.rotary_emb = MimiRotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - base=self.rope_theta, - ) + self.rotary_emb = MimiRotaryEmbedding(config) self.sliding_window = config.sliding_window # Ignore copy def forward( @@ -559,7 +593,8 @@ def forward( return attn_output, attn_weights, past_key_value -# Copied from transformers.models.gemma.modeling_gemma.GemmaFlashAttention2 with Gemma->Mimi +# NO LONGER EXIST Copied from transformers.models.gemma.modeling_gemma.GemmaFlashAttention2 with Gemma->Mimi +# TODO cyril: modular class MimiFlashAttention2(MimiAttention): """ Mimi flash attention module. This module inherits from `MimiAttention` as the weights of the module stays @@ -670,7 +705,8 @@ def forward( return attn_output, attn_weights, past_key_value -# Copied from transformers.models.gemma.modeling_gemma.GemmaSdpaAttention with Gemma->Mimi +# NO LONGER EXIST Copied from transformers.models.gemma.modeling_gemma.GemmaSdpaAttention with Gemma->Mimi +# TODO cyril: modular class MimiSdpaAttention(MimiAttention): """ Mimi attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 6ed8178ed9821e..90c38895b4280b 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -1,36 +1,19 @@ -# coding=utf-8 -# Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""PyTorch Mistral model.""" - -import math -from typing import List, Optional, Tuple, Union +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/mistral/modular_mistral.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_mistral.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +from typing import Callable, List, Optional, Tuple, Union import torch -import torch.utils.checkpoint from torch import nn -from torch.nn import CrossEntropyLoss from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -38,79 +21,42 @@ SequenceClassifierOutputWithPast, TokenClassifierOutput, ) -from ...modeling_utils import PreTrainedModel +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack from ...utils import ( + LossKwargs, add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, ) from .configuration_mistral import MistralConfig -if is_flash_attn_2_available(): - from ...modeling_flash_attention_utils import _flash_attention_forward - logger = logging.get_logger(__name__) _CHECKPOINT_FOR_DOC = "mistralai/Mistral-7B-v0.1" _CONFIG_FOR_DOC = "MistralConfig" -# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Mistral -class MistralRMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - MistralRMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - - def extra_repr(self): - return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" - - -class MistralRotaryEmbedding(nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): +class MistralMLP(nn.Module): + def __init__(self, config): super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = 
nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - - @torch.no_grad() - # copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.forward - # TODO(joao): add me back asap :) - def forward(self, x, position_ids): - # x: [bs, num_attention_heads, seq_len, head_size] - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) - position_ids_expanded = position_ids[:, None, :].float() - # Force float32 since bfloat16 loses precision on long contexts - # See https://github.com/huggingface/transformers/pull/29285 - device_type = x.device.type - device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() - sin = emb.sin() - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj -# Copied from transformers.models.llama.modeling_llama.rotate_half def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] @@ -118,7 +64,6 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -146,21 +91,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): return q_embed, k_embed -class MistralMLP(nn.Module): - def __init__(self, config): - super().__init__() - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, hidden_state): - return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) - - -# Copied from transformers.models.llama.modeling_llama.repeat_kv def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
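`MistralMLP` above (like `JambaMLP` earlier in this patch) reduces to the same gated feed-forward: `down_proj(act(gate_proj(x)) * up_proj(x))`. A tiny self-contained sketch of that gated block, using SiLU as a stand-in for `ACT2FN[config.hidden_act]` and made-up sizes:

```python
import torch
from torch import nn

class GatedMLP(nn.Module):
    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.act_fn = nn.SiLU()  # stand-in for ACT2FN[config.hidden_act]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The gate branch goes through the nonlinearity and multiplies the "up" branch elementwise.
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

mlp = GatedMLP(hidden_size=32, intermediate_size=128)
print(mlp(torch.randn(2, 5, 32)).shape)  # torch.Size([2, 5, 32])
```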
The hidden states go from (batch, @@ -173,65 +103,66 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + class MistralAttention(nn.Module): - """ - Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer - and "Generating Long Sequences with Sparse Transformers". - """ + """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__(self, config: MistralConfig, layer_idx: Optional[int] = None): + def __init__(self, config: MistralConfig, layer_idx: int): super().__init__() self.config = config self.layer_idx = layer_idx - if layer_idx is None: - logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " - "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class." 
- ) - + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 self.attention_dropout = config.attention_dropout - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = config.head_dim - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.max_position_embeddings = config.max_position_embeddings - self.rope_theta = config.rope_theta self.is_causal = True - - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) - - self.rotary_emb = MistralRotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - base=self.rope_theta, - ) + self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) - query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - cos, sin = self.rotary_emb(value_states, position_ids) + cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: @@ -239,249 +170,58 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - - if attention_mask is not None: # no matter the length, 
we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - - attn_output = attn_output.view(bsz, q_len, -1) - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class MistralFlashAttention2(MistralAttention): - """ - Mistral flash attention module. This module inherits from `MistralAttention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
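Both the Llama and Mistral forwards above keep the same cache step: `past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)` returns the full keys/values seen so far. The sketch below is a simplified illustration of what such an update does for a dynamic cache (append along the sequence axis per layer); it assumes layers are updated in order and is not the library's `DynamicCache` implementation.

```python
from typing import Dict, List, Optional, Tuple
import torch

class TinyDynamicCache:
    """Simplified stand-in for the per-layer update performed by `past_key_value.update(...)`."""

    def __init__(self):
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []

    def update(
        self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int, cache_kwargs: Optional[Dict] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # States are (batch, num_kv_heads, seq_len, head_dim); new tokens are appended on the seq axis.
        if len(self.key_cache) <= layer_idx:  # first call for this layer (assumes in-order layers)
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)
        else:
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
        return self.key_cache[layer_idx], self.value_cache[layer_idx]

cache = TinyDynamicCache()
k, v = cache.update(torch.randn(1, 2, 4, 64), torch.randn(1, 2, 4, 64), layer_idx=0)  # prefill
k, v = cache.update(torch.randn(1, 2, 1, 64), torch.randn(1, 2, 1, 64), layer_idx=0)  # one decode step
print(k.shape)  # torch.Size([1, 2, 5, 64])
```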
- self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - ): - if isinstance(past_key_value, StaticCache): - raise ValueError( - "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " - "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" - ) - - output_attentions = False - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - - cos, sin = self.rotary_emb(value_states, position_ids) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - dropout_rate = 0.0 if not self.training else self.attention_dropout - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in float16 just to be sure everything works as expected. - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) else: - target_dtype = self.q_proj.weight.dtype + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." 
- ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - # Reashape to the expected shape for Flash Attention - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - attn_output = _flash_attention_forward( + attn_output, attn_weights = attention_interface( + self, query_states, key_states, value_states, attention_mask, - q_len, - position_ids=position_ids, - dropout=dropout_rate, - sliding_window=getattr(self.config, "sliding_window", None), - use_top_left_mask=self._flash_attn_uses_top_left_mask, - is_causal=self.is_causal, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=getattr(self.config, "sliding_window", None), # main diff with Llama + **kwargs, ) - attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous() + attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) + return attn_output, attn_weights - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -# copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Mistral -# TODO(joao): add me back asap :) -class MistralSdpaAttention(MistralAttention): - """ - Mistral attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from - `MistralAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to - SDPA API. - """ - - # Adapted from MistralAttention.forward - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "MistralModel is using MistralSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
- ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - ) - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - cos, sin = self.rotary_emb(value_states, position_ids) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - causal_mask = attention_mask - if attention_mask is not None: - causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and causal_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. 
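In the Mistral hunk, the functional difference from the Llama path is that `sliding_window` from the config is forwarded to the attention backend. The snippet below sketches what that window means in mask form: each query may attend only to the most recent `window` keys, causally. This is a conceptual illustration only; the real handling lives inside the attention backends and mask utilities, and the exact window convention here is an assumption of the example.

```python
import torch

def sliding_window_causal_mask(seq_len: int, window: int, dtype=torch.float32) -> torch.Tensor:
    # (seq_len, seq_len) additive mask: 0 where attention is allowed, a large negative value elsewhere.
    q_idx = torch.arange(seq_len)[:, None]
    k_idx = torch.arange(seq_len)[None, :]
    allowed = (k_idx <= q_idx) & (k_idx > q_idx - window)  # causal AND within the last `window` positions
    mask = torch.zeros(seq_len, seq_len, dtype=dtype)
    mask.masked_fill_(~allowed, torch.finfo(dtype).min)
    return mask

print(sliding_window_causal_mask(5, window=3))
```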
- is_causal = True if causal_mask is None and q_len > 1 else False - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=causal_mask, - dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=is_causal, - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(bsz, q_len, -1) - - attn_output = self.o_proj(attn_output) - - return attn_output, None, past_key_value +class MistralRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + MistralRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) -MISTRAL_ATTENTION_CLASSES = { - "eager": MistralAttention, - "flash_attention_2": MistralFlashAttention2, - "sdpa": MistralSdpaAttention, -} + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" -# copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with Llama->Mistral, LLAMA->MISTRAL -# TODO(joao): add me back asap :) class MistralDecoderLayer(nn.Module): def __init__(self, config: MistralConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size - - self.self_attn = MISTRAL_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) - + self.self_attn = MistralAttention(config=config, layer_idx=layer_idx) self.mlp = MistralMLP(config) self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -495,33 +235,15 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - **kwargs, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): - attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, - query_sequence_length, key_sequence_length)` if default attention is used. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). 
- past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence - kwargs (`dict`, *optional*): - Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code - into the model - """ residual = hidden_states hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, @@ -529,6 +251,7 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + position_embeddings=position_embeddings, **kwargs, ) hidden_states = residual + hidden_states @@ -540,16 +263,77 @@ def forward( hidden_states = residual + hidden_states outputs = (hidden_states,) - if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) - return outputs +class MistralRotaryEmbedding(nn.Module): + def __init__( + self, + config: MistralConfig, + device=None, + ): + super().__init__() + self.rope_kwargs = {} + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb 
= torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + MISTRAL_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads @@ -576,10 +360,11 @@ class MistralPreTrainedModel(PreTrainedModel): base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["MistralDecoderLayer"] - _skip_keys_device_placement = "past_key_values" + _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True _supports_sdpa = True _supports_cache_class = True + _supports_quantized_cache = True _supports_static_cache = True def _init_weights(self, module): @@ -663,7 +448,7 @@ def _init_weights(self, module): return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices indicating the position of the input sequence tokens in the sequence. Unlike `position_ids`, + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, this tensor is not affected by padding. It is used to update the cache in the correct position and to infer the complete sequence length. """ @@ -690,10 +475,10 @@ def __init__(self, config: MistralConfig): self.layers = nn.ModuleList( [MistralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) - self._attn_implementation = config._attn_implementation self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - + self.rotary_emb = MistralRotaryEmbedding(config=config) self.gradient_checkpointing = False + # Initialize weights and apply final processing self.post_init() @@ -709,48 +494,36 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - # retrieve input_ids and inputs_embeds if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") if self.gradient_checkpointing and self.training and use_cache: logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + "`use_cache=True` is incompatible with gradient checkpointing. 
Setting `use_cache=False`." ) use_cache = False if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - # kept for BC (non `Cache` `past_key_values` inputs) - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True - if past_key_values is None: - past_key_values = DynamicCache() - else: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " - "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class " - "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" - ) + if use_cache and past_key_values is None: + past_key_values = DynamicCache() if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 @@ -762,17 +535,19 @@ def forward( position_ids = cache_position.unsqueeze(0) causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_key_values, use_cache, output_attentions + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions ) hidden_states = inputs_embeds + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_decoder_cache = None - for decoder_layer in self.layers: + for decoder_layer in self.layers[: self.config.num_hidden_layers]: if output_hidden_states: all_hidden_states += (hidden_states,) @@ -786,6 +561,7 @@ def forward( output_attentions, use_cache, cache_position, + position_embeddings, ) else: layer_outputs = decoder_layer( @@ -796,13 +572,12 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -812,18 +587,13 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None - if return_legacy_cache: - next_cache = next_cache.to_legacy_cache() - - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( + output = BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, ) + return output if return_dict else output.to_tuple() def _update_causal_mask( self, @@ -831,11 +601,10 @@ def _update_causal_mask( input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - use_cache: bool, output_attentions: bool, ): if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and use_cache: + if attention_mask is not None and past_key_values is not None: is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] if is_padding_right: raise ValueError( @@ -977,6 +746,9 @@ def _prepare_4d_causal_attention_mask_with_cache_position( return causal_mask +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... 
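The one-line `KwargsForCausalLM` class above merges two `TypedDict`-style keyword groups so that `**kwargs: Unpack[KwargsForCausalLM]` on `MistralForCausalLM.forward` below can accept both flash-attention and loss-related keyword arguments and pass them through untouched. A minimal, self-contained sketch of that typing pattern follows; the field names are illustrative stand-ins, not the library's actual definitions.

from typing import Optional
from typing_extensions import TypedDict, Unpack


class FlashAttentionKwargsSketch(TypedDict, total=False):
    # Illustrative stand-in fields for the real FlashAttentionKwargs.
    max_length_q: Optional[int]
    max_length_k: Optional[int]


class LossKwargsSketch(TypedDict, total=False):
    # Illustrative stand-in field for the real LossKwargs.
    num_items_in_batch: Optional[int]


class KwargsForCausalLMSketch(FlashAttentionKwargsSketch, LossKwargsSketch): ...


def lm_forward(input_ids, **kwargs: Unpack[KwargsForCausalLMSketch]):
    # Type checkers validate the keyword names and types against the merged TypedDict;
    # at runtime the kwargs are simply collected and forwarded (to attention, to the loss).
    return sorted(kwargs)


print(lm_forward([1, 2, 3], num_items_in_batch=7, max_length_q=128))
# ['max_length_q', 'num_items_in_batch']

At runtime the annotation is inert; it only gives static type checkers a closed set of accepted keyword names while keeping every intermediate `forward` signature short.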
+ + class MistralForCausalLM(MistralPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} @@ -1024,6 +796,7 @@ def forward( return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, num_logits_to_keep: int = 0, + **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -1044,8 +817,8 @@ def forward( ```python >>> from transformers import AutoTokenizer, MistralForCausalLM - >>> model = MistralForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1") - >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1") + >>> model = MistralForCausalLM.from_pretrained("meta-mistral/Mistral-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-mistral/Mistral-2-7b-hf") >>> prompt = "Hey, are you conscious? Can you talk to me?" >>> inputs = tokenizer(prompt, return_tensors="pt") @@ -1055,7 +828,6 @@ def forward( >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -1074,6 +846,7 @@ def forward( output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, + **kwargs, ) hidden_states = outputs[0] @@ -1082,18 +855,7 @@ def forward( loss = None if labels is not None: - # Upcast to float if we need to compute the loss to avoid potential precision issues - logits = logits.float() - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = shift_labels.view(-1) - # Ensure tensors are on the same device - shift_labels = shift_labels.to(shift_logits.device) - loss_fct = CrossEntropyLoss() - loss = loss_fct(shift_logits, shift_labels) + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) if not return_dict: output = (logits,) + outputs[1:] @@ -1110,26 +872,24 @@ def forward( @add_start_docstrings( """ - The Mistral Model transformer with a sequence classification head on top (linear layer). - - [`MistralForSequenceClassification`] uses the last token in order to do the classification, as other causal models - (e.g. GPT-2) do. - - Since it does classification on the last token, it requires to know the position of the last token. If a - `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If - no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the - padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in - each row of the batch). + The Mistral Model transformer with a token classification head on top (a linear layer on top of the hidden-states + output) e.g. for Named-Entity-Recognition (NER) tasks. 
""", MISTRAL_START_DOCSTRING, ) -# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Mistral, LLAMA->MISTRAL -class MistralForSequenceClassification(MistralPreTrainedModel): +class MistralForTokenClassification(MistralPreTrainedModel): def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels self.model = MistralModel(config) - self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + if getattr(config, "classifier_dropout", None) is not None: + classifier_dropout = config.classifier_dropout + elif getattr(config, "hidden_dropout", None) is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.score = nn.Linear(config.hidden_size, config.num_labels) # Initialize weights and apply final processing self.post_init() @@ -1141,19 +901,24 @@ def set_input_embeddings(self, value): self.model.embed_tokens = value @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + ) -> Union[Tuple, TokenClassifierOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. 
Indices should be in `[0, ..., @@ -1162,7 +927,7 @@ def forward( """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict - transformer_outputs = self.model( + outputs = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -1173,67 +938,47 @@ def forward( output_hidden_states=output_hidden_states, return_dict=return_dict, ) - hidden_states = transformer_outputs[0] - logits = self.score(hidden_states) - - if input_ids is not None: - batch_size = input_ids.shape[0] - else: - batch_size = inputs_embeds.shape[0] - - if self.config.pad_token_id is None and batch_size != 1: - raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") - if self.config.pad_token_id is None: - sequence_lengths = -1 - else: - if input_ids is not None: - # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility - sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 - sequence_lengths = sequence_lengths % input_ids.shape[-1] - sequence_lengths = sequence_lengths.to(logits.device) - else: - sequence_lengths = -1 - - pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + sequence_output = outputs[0] + sequence_output = self.dropout(sequence_output) + logits = self.score(sequence_output) loss = None if labels is not None: - loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) + loss = self.loss_function(logits, labels, self.config) if not return_dict: - output = (pooled_logits,) + transformer_outputs[1:] + output = (logits,) + outputs[2:] return ((loss,) + output) if loss is not None else output - return SequenceClassifierOutputWithPast( + return TokenClassifierOutput( loss=loss, - logits=pooled_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, ) @add_start_docstrings( """ - The Mistral Model transformer with a token classification head on top (a linear layer on top of the hidden-states - output) e.g. for Named-Entity-Recognition (NER) tasks. + The Mistral Model transformer with a sequence classification head on top (linear layer). + + [`MistralForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). 
""", MISTRAL_START_DOCSTRING, ) -# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Mistral, LLAMA->MISTRAL -class MistralForTokenClassification(MistralPreTrainedModel): +class MistralForSequenceClassification(MistralPreTrainedModel): def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels self.model = MistralModel(config) - if getattr(config, "classifier_dropout", None) is not None: - classifier_dropout = config.classifier_dropout - elif getattr(config, "hidden_dropout", None) is not None: - classifier_dropout = config.hidden_dropout - else: - classifier_dropout = 0.1 - self.dropout = nn.Dropout(classifier_dropout) - self.score = nn.Linear(config.hidden_size, config.num_labels) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) # Initialize weights and apply final processing self.post_init() @@ -1245,24 +990,19 @@ def set_input_embeddings(self, value): self.model.embed_tokens = value @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TokenClassifierOutput, - config_class=_CONFIG_FOR_DOC, - ) def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - ) -> Union[Tuple, TokenClassifierOutput]: + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. 
Indices should be in `[0, ..., @@ -1271,7 +1011,7 @@ def forward( """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.model( + transformer_outputs = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -1282,23 +1022,43 @@ def forward( output_hidden_states=output_hidden_states, return_dict=return_dict, ) - sequence_output = outputs[0] - sequence_output = self.dropout(sequence_output) - logits = self.score(sequence_output) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] loss = None if labels is not None: - loss = self.loss_function(logits, labels, self.config) + loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) if not return_dict: - output = (logits,) + outputs[2:] + output = (pooled_logits,) + transformer_outputs[1:] return ((loss,) + output) if loss is not None else output - return TokenClassifierOutput( + return SequenceClassifierOutputWithPast( loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, ) @@ -1309,15 +1069,13 @@ def forward( """, MISTRAL_START_DOCSTRING, ) -# Copied from transformers.models.llama.modeling_llama.LlamaForQuestionAnswering with Llama->Mistral,LLAMA->MISTRAL,transformer->model class MistralForQuestionAnswering(MistralPreTrainedModel): base_model_prefix = "model" - # Copied from models.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Mistral def __init__(self, config): super().__init__(config) - self.model = MistralModel(config) self.qa_outputs = nn.Linear(config.hidden_size, 2) + self.model = MistralModel(config) # diff with Llama: transformer->model # Initialize weights and apply final processing self.post_init() diff --git a/src/transformers/models/mistral/modular_mistral.py b/src/transformers/models/mistral/modular_mistral.py new file mode 100644 index 00000000000000..362233a21b70f4 --- /dev/null +++ b/src/transformers/models/mistral/modular_mistral.py @@ -0,0 +1,350 @@ +from typing import Callable, List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn + +from ...cache_utils import Cache, SlidingWindowCache, StaticCache +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_outputs import QuestionAnsweringModelOutput +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS +from ...processing_utils import Unpack +from ...utils import logging +from 
..llama.modeling_llama import (
+    LlamaAttention,
+    LlamaDecoderLayer,
+    LlamaForCausalLM,
+    LlamaForQuestionAnswering,
+    LlamaForSequenceClassification,
+    LlamaForTokenClassification,
+    LlamaMLP,
+    LlamaModel,
+    apply_rotary_pos_emb,
+    eager_attention_forward,
+)
+from .configuration_mistral import MistralConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "mistralai/Mistral-7B-v0.1"
+
+
+class MistralMLP(LlamaMLP):
+    def __init__(self, config):
+        super().__init__(config)
+        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+
+
+class MistralAttention(LlamaAttention):
+    def __init__(self, config: MistralConfig, layer_idx: int):
+        super().__init__(config, layer_idx)
+        self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
+        self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
+        self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
+        self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
+        attention_mask: Optional[torch.Tensor],
+        past_key_value: Optional[Cache] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs: Unpack[FlashAttentionKwargs],
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        input_shape = hidden_states.shape[:-1]
+        hidden_shape = (*input_shape, -1, self.head_dim)
+
+        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+        cos, sin = position_embeddings
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+        if past_key_value is not None:
+            # sin and cos are specific to RoPE models; cache_position needed for the static cache
+            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+        attention_interface: Callable = eager_attention_forward
+        if self.config._attn_implementation != "eager":
+            if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
+                logger.warning_once(
+                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
+                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+ ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=getattr(self.config, "sliding_window", None), # main diff with Llama + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class MistralDecoderLayer(LlamaDecoderLayer): + def __init__(self, config: MistralConfig, layer_idx: int): + super().__init__(config, layer_idx) + self.self_attn = MistralAttention(config=config, layer_idx=layer_idx) + self.mlp = MistralMLP(config) + + +class MistralModel(LlamaModel): + def __init__(self, config: MistralConfig): + super().__init__(config) + self.layers = nn.ModuleList( + [MistralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and past_key_values is not None: + is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] + if is_padding_right: + raise ValueError( + "You are attempting to perform batched generation with padding_side='right'" + " this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to " + " call `tokenizer.padding_side = 'left'` before tokenizing the input. " + ) + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if ( + self.config._attn_implementation == "sdpa" + and not (using_static_cache or using_sliding_window_cache) + and not output_attentions + ): + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + sliding_window=self.config.sliding_window, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + # SlidingWindowCache or StaticCache + if using_sliding_window_cache or using_static_cache: + target_length = past_key_values.get_max_cache_shape() + # DynamicCache or no cache + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). 
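+        # The helper below returns an additive float mask of shape (batch_size, 1, sequence_length, target_length):
+        # positions that may be attended to hold 0.0, while everything else (future tokens, tokens beyond the
+        # sliding window, padding) holds torch.finfo(dtype).min, so the mask can be added directly to the scores.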
+        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
+            attention_mask,
+            sequence_length=sequence_length,
+            target_length=target_length,
+            dtype=dtype,
+            device=device,
+            cache_position=cache_position,
+            batch_size=input_tensor.shape[0],
+            config=self.config,
+            past_key_values=past_key_values,
+        )
+
+        if (
+            self.config._attn_implementation == "sdpa"
+            and attention_mask is not None
+            and attention_mask.device.type == "cuda"
+            and not output_attentions
+        ):
+            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+            # Details: https://github.com/pytorch/pytorch/issues/110213
+            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+
+        return causal_mask
+
+    @staticmethod
+    def _prepare_4d_causal_attention_mask_with_cache_position(
+        attention_mask: torch.Tensor,
+        sequence_length: int,
+        target_length: int,
+        dtype: torch.dtype,
+        device: torch.device,
+        cache_position: torch.Tensor,
+        batch_size: int,
+        config: MistralConfig,
+        past_key_values: Cache,
+    ):
+        """
+        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
+
+        Args:
+            attention_mask (`torch.Tensor`):
+                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
+            sequence_length (`int`):
+                The sequence length being processed.
+            target_length (`int`):
+                The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, i.e. the part of the cache that is not filled yet.
+            dtype (`torch.dtype`):
+                The dtype to use for the 4D attention mask.
+            device (`torch.device`):
+                The device to place the 4D attention mask on.
+            cache_position (`torch.Tensor`):
+                Indices depicting the position of the input sequence tokens in the sequence.
+            batch_size (`int`):
+                Batch size.
+            config (`MistralConfig`):
+                The model's configuration class.
+            past_key_values (`Cache`):
+                The cache class that is currently being used to generate.
+        """
+        if attention_mask is not None and attention_mask.dim() == 4:
+            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
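+            # "Inverted form" means the mask already follows the additive convention: 0.0 where attention is
+            # allowed and a large negative value (e.g. torch.finfo(dtype).min) where it is masked, so it can be
+            # passed through unchanged.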
+ causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + if config.sliding_window is not None: + # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also + # the check is needed to verify is current checkpoint was trained with sliding window or not + if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: + sliding_attend_mask = torch.arange(target_length, device=device) <= ( + cache_position.reshape(-1, 1) - config.sliding_window + ) + diagonal_attend_mask.bitwise_or_(sliding_attend_mask) + causal_mask *= diagonal_attend_mask + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + if attention_mask.shape[-1] > target_length: + attention_mask = attention_mask[:, :target_length] + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + return causal_mask + + +class MistralForCausalLM(LlamaForCausalLM): + pass + + +class MistralForTokenClassification(LlamaForTokenClassification): + pass + + +class MistralForSequenceClassification(LlamaForSequenceClassification): + pass + + +class MistralForQuestionAnswering(LlamaForQuestionAnswering): + base_model_prefix = "model" + + def __init__(self, config): + super().__init__(config) + self.model = MistralModel(config) # diff with Llama: transformer->model + del self.transformer + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple, QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + loss = None + if start_positions is not None and end_positions is not None: + loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs) + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return QuestionAnsweringModelOutput( + loss=loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 0f04ef255c431d..84ed327d9be920 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -1,3 +1,9 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/mixtral/modular_mixtral.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_mixtral.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # coding=utf-8 # Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved. # @@ -17,142 +23,133 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
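For reference on the `MistralForQuestionAnswering.forward` that closes modular_mistral.py just above: the QA head projects every hidden state to two scores, which are split into per-token start and end logits for extractive question answering. A minimal sketch of that mechanism on toy tensors; the shapes and toy labels below are illustrative only, not taken from the patch.

import torch
from torch import nn

batch_size, seq_len, hidden_size = 2, 6, 8
sequence_output = torch.randn(batch_size, seq_len, hidden_size)  # stand-in for the model's last hidden states

qa_outputs = nn.Linear(hidden_size, 2)                 # same shape as `self.qa_outputs` in the patch
logits = qa_outputs(sequence_output)                   # (batch, seq_len, 2)
start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1).contiguous()   # (batch, seq_len)
end_logits = end_logits.squeeze(-1).contiguous()       # (batch, seq_len)

# Toy labels: the start/end token indices of the answer span; cross-entropy over the sequence
# dimension, averaged over the two endpoints.
start_positions = torch.tensor([1, 3])
end_positions = torch.tensor([2, 5])
loss_fct = nn.CrossEntropyLoss()
loss = (loss_fct(start_logits, start_positions) + loss_fct(end_logits, end_positions)) / 2
print(start_logits.shape, end_logits.shape, float(loss))

The loss computed via `self.loss_function` in the patch additionally handles positions that fall outside the sequence before averaging the two cross-entropy terms.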
-"""PyTorch Mixtral model.""" -import math -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch import torch.nn.functional as F -import torch.utils.checkpoint from torch import nn from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache from ...generation import GenerationMixin -from ...modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_causal_attention_mask +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, MoeCausalLMOutputWithPast, MoeModelOutputWithPast, QuestionAnsweringModelOutput, SequenceClassifierOutputWithPast, TokenClassifierOutput, ) -from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import is_torch_greater_or_equal_than_1_13 +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack from ...utils import ( + LossKwargs, add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_2_available, logging, replace_return_docstrings, ) -from ...utils.import_utils import is_torch_fx_available from .configuration_mixtral import MixtralConfig -if is_flash_attn_2_available(): - from ...modeling_flash_attention_utils import _flash_attention_forward - -# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph. -# It means that the function will not be traced through and simply appear as a node in the graph. -if is_torch_fx_available(): - if not is_torch_greater_or_equal_than_1_13: - import torch.fx - - _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask) - - logger = logging.get_logger(__name__) _CHECKPOINT_FOR_DOC = "mistralai/Mixtral-8x7B-v0.1" _CONFIG_FOR_DOC = "MixtralConfig" -def load_balancing_loss_func( - gate_logits: Union[torch.Tensor, Tuple[torch.Tensor], None], - num_experts: Optional[int] = None, - top_k=2, - attention_mask: Optional[torch.Tensor] = None, -) -> Union[torch.Tensor, int]: - r""" - Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. +class MixtralBlockSparseTop2MLP(nn.Module): + def __init__(self, config: MixtralConfig): + super().__init__() + self.ffn_dim = config.intermediate_size + self.hidden_dim = config.hidden_size - See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss - function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between - experts is too unbalanced. + self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) + self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False) + self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) - Args: - gate_logits: - Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of - shape [batch_size X sequence_length, num_experts]. - num_experts: - Number of experts - top_k: - The number of experts to route per-token, can be also interpreted as the `top-k` routing - parameter. - attention_mask (`torch.Tensor`, *optional*): - The attention_mask used in forward function - shape [batch_size X sequence_length] if not None. + self.act_fn = ACT2FN[config.hidden_act] - Returns: - The auxiliary loss. 
+ def forward(self, hidden_states): + current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states) + current_hidden_states = self.w2(current_hidden_states) + return current_hidden_states + + +class MixtralSparseMoeBlock(nn.Module): + """ + This implementation is + strictly equivalent to standard MoE with full capacity (no + dropped tokens). It's faster since it formulates MoE operations + in terms of block-sparse operations to accommodate imbalanced + assignments of tokens to experts, whereas standard MoE either + (1) drop tokens at the cost of reduced performance or (2) set + capacity factor to number of experts and thus waste computation + and memory on padding. """ - if gate_logits is None or not isinstance(gate_logits, tuple): - return 0 - if isinstance(gate_logits, tuple): - compute_device = gate_logits[0].device - concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) + def __init__(self, config): + super().__init__() + self.hidden_dim = config.hidden_size + self.ffn_dim = config.intermediate_size + self.num_experts = config.num_local_experts + self.top_k = config.num_experts_per_tok - routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1) + # gating + self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False) - _, selected_experts = torch.topk(routing_weights, top_k, dim=-1) + self.experts = nn.ModuleList([MixtralBlockSparseTop2MLP(config) for _ in range(self.num_experts)]) - expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts) + # Jitter parameters + self.jitter_noise = config.router_jitter_noise - if attention_mask is None: - # Compute the percentage of tokens routed to each experts - tokens_per_expert = torch.mean(expert_mask.float(), dim=0) + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """ """ + batch_size, sequence_length, hidden_dim = hidden_states.shape + if self.training and self.jitter_noise > 0: + hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise) + hidden_states = hidden_states.view(-1, hidden_dim) + # router_logits: (batch * sequence_length, n_experts) + router_logits = self.gate(hidden_states) - # Compute the average probability of routing to these experts - router_prob_per_expert = torch.mean(routing_weights, dim=0) - else: - batch_size, sequence_length = attention_mask.shape - num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length) + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + # we cast back to the input dtype + routing_weights = routing_weights.to(hidden_states.dtype) - # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask - expert_attention_mask = ( - attention_mask[None, :, :, None, None] - .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts)) - .reshape(-1, top_k, num_experts) - .to(compute_device) + final_hidden_states = torch.zeros( + (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device ) - # Compute the percentage of tokens routed to each experts - tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum( - expert_attention_mask, dim=0 - ) + # One hot encode the selected experts to create an expert mask + # this 
will be used to easily index which expert is going to be sollicitated + expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) - # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert - router_per_expert_attention_mask = ( - attention_mask[None, :, :, None] - .expand((num_hidden_layers, batch_size, sequence_length, num_experts)) - .reshape(-1, num_experts) - .to(compute_device) - ) + # Loop over all available experts in the model and perform the computation on each expert + for expert_idx in range(self.num_experts): + expert_layer = self.experts[expert_idx] + idx, top_x = torch.where(expert_mask[expert_idx]) - # Compute the average probability of routing to these experts - router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum( - router_per_expert_attention_mask, dim=0 - ) + # Index the correct hidden states and compute the expert hidden state for + # the current expert. We need to make sure to multiply the output hidden + # states by `routing_weights` on the corresponding tokens (top-1 and top-2) + current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) + current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] - overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) - return overall_loss * num_experts + # However `index_add_` only support torch tensors for indexing so we'll use + # the `top_x` tensor here. + final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) + final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) + return final_hidden_states, router_logits -# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Mixtral class MixtralRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ @@ -173,45 +170,6 @@ def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" -# copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Mixtral -# TODO @longjie no longer copied from Mistral after static cache -class MixtralRotaryEmbedding(nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): - super().__init__() - - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - - # Build here to make `torch.jit.trace` work. 
- self._set_cos_sin_cache( - seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() - ) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) - - freqs = torch.outer(t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) - - def forward(self, x, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) - - return ( - self.cos_cached[:seq_len].to(dtype=x.dtype), - self.sin_cached[:seq_len].to(dtype=x.dtype), - ) - - -# Copied from transformers.models.llama.modeling_llama.rotate_half def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] @@ -219,9 +177,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb -# TODO @longjie no longer copied from Mistral after static cache -def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: @@ -229,9 +185,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`): - The position indices of the tokens corresponding to the query and key tensors. For example, this can be - used to pass offsetted position ids when working with a KV-cache. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note @@ -242,14 +197,13 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): Returns: `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. """ - cos = cos[position_ids].unsqueeze(unsqueeze_dim) - sin = sin[position_ids].unsqueeze(unsqueeze_dim) + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed -# Copied from transformers.models.llama.modeling_llama.repeat_kv def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, @@ -262,412 +216,98 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -# copied from transformers.models.mistral.modeling_mistral.MistralAttention with Mistral->Mixtral -# TODO @longjie no longer copied from Mistral after static cache +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + class MixtralAttention(nn.Module): - """ - Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer - and "Generating Long Sequences with Sparse Transformers". - """ + """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__(self, config: MixtralConfig, layer_idx: Optional[int] = None): + def __init__(self, config: MixtralConfig, layer_idx: int): super().__init__() self.config = config self.layer_idx = layer_idx - if layer_idx is None: - logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " - "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class." 
- ) - - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = config.head_dim - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.max_position_embeddings = config.max_position_embeddings - self.rope_theta = config.rope_theta - self.is_causal = True + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 self.attention_dropout = config.attention_dropout - - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) - - self.rotary_emb = MixtralRotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - base=self.rope_theta, - ) - - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + self.is_causal = True + self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - kv_seq_len = key_states.shape[-2] if past_key_value is not None: - if self.layer_idx is None: - raise ValueError( - f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " - "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " - "with a layer index." 
- ) - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - - if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - raise ValueError( - f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, -1) - - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -# copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with Mistral->Mixtral -# TODO @longjie no longer copied from Mistral after static cache -class MixtralFlashAttention2(MixtralAttention): - """ - Mixtral flash attention module. This module inherits from `MixtralAttention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. 
- """ - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - ): - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - if self.layer_idx is None: - raise ValueError( - f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " - "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " - "with a layer index." + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' ) - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - - # Because the input can be padded, the absolute sequence length depends on the max position id. - rotary_seq_len = ( - max(kv_seq_len, position_ids[:, -1].max().item() + 1) if position_ids is not None else kv_seq_len - ) - - cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) - - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - - if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - dropout_rate = 0.0 if not self.training else self.attention_dropout - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in float16 just to be sure everything works as expected. - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." 
- ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - # Reashape to the expected shape for Flash Attention - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - attn_output = _flash_attention_forward( + attn_output, attn_weights = attention_interface( + self, query_states, key_states, value_states, attention_mask, - q_len, - position_ids=position_ids, - dropout=dropout_rate, - sliding_window=getattr(self.config, "sliding_window", None), - is_causal=self.is_causal, - ) - - attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -# copied from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Mixtral -# TODO @longjie no longer copied from Mistral after static cache -class MixtralSdpaAttention(MixtralAttention): - """ - Mixtral attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from - `MixtralAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to - SDPA API. - """ - - # Adapted from MixtralAttention.forward - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "MixtralModel is using MixtralSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
- ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - - if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - causal_mask = attention_mask - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and attention_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. 
- is_causal = True if causal_mask is None and q_len > 1 else False - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=causal_mask, - dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=is_causal, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=getattr(self.config, "sliding_window", None), # main diff with Llama + **kwargs, ) - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(bsz, q_len, -1) - + attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) - - return attn_output, None, past_key_value - - -MIXTRAL_ATTENTION_CLASSES = { - "eager": MixtralAttention, - "flash_attention_2": MixtralFlashAttention2, - "sdpa": MixtralSdpaAttention, -} - - -class MixtralBlockSparseTop2MLP(nn.Module): - def __init__(self, config: MixtralConfig): - super().__init__() - self.ffn_dim = config.intermediate_size - self.hidden_dim = config.hidden_size - - self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) - self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False) - self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) - - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, hidden_states): - current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states) - current_hidden_states = self.w2(current_hidden_states) - return current_hidden_states - - -class MixtralSparseMoeBlock(nn.Module): - """ - This implementation is - strictly equivalent to standard MoE with full capacity (no - dropped tokens). It's faster since it formulates MoE operations - in terms of block-sparse operations to accommodate imbalanced - assignments of tokens to experts, whereas standard MoE either - (1) drop tokens at the cost of reduced performance or (2) set - capacity factor to number of experts and thus waste computation - and memory on padding. 
- """ - - def __init__(self, config): - super().__init__() - self.hidden_dim = config.hidden_size - self.ffn_dim = config.intermediate_size - self.num_experts = config.num_local_experts - self.top_k = config.num_experts_per_tok - - # gating - self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False) - - self.experts = nn.ModuleList([MixtralBlockSparseTop2MLP(config) for _ in range(self.num_experts)]) - - # Jitter parameters - self.jitter_noise = config.router_jitter_noise - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - """ """ - batch_size, sequence_length, hidden_dim = hidden_states.shape - if self.training and self.jitter_noise > 0: - hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise) - hidden_states = hidden_states.view(-1, hidden_dim) - # router_logits: (batch * sequence_length, n_experts) - router_logits = self.gate(hidden_states) - - routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) - routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) - routing_weights /= routing_weights.sum(dim=-1, keepdim=True) - # we cast back to the input dtype - routing_weights = routing_weights.to(hidden_states.dtype) - - final_hidden_states = torch.zeros( - (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device - ) - - # One hot encode the selected experts to create an expert mask - # this will be used to easily index which expert is going to be sollicitated - expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) - - # Loop over all available experts in the model and perform the computation on each expert - for expert_idx in range(self.num_experts): - expert_layer = self.experts[expert_idx] - idx, top_x = torch.where(expert_mask[expert_idx]) - - # Index the correct hidden states and compute the expert hidden state for - # the current expert. We need to make sure to multiply the output hidden - # states by `routing_weights` on the corresponding tokens (top-1 and top-2) - current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) - current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] - - # However `index_add_` only support torch tensors for indexing so we'll use - # the `top_x` tensor here. 
- final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) - final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) - return final_hidden_states, router_logits + return attn_output, attn_weights class MixtralDecoderLayer(nn.Module): @@ -675,7 +315,7 @@ def __init__(self, config: MixtralConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = MIXTRAL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) + self.self_attn = MixtralAttention(config, layer_idx) self.block_sparse_moe = MixtralSparseMoeBlock(config) self.input_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -691,7 +331,8 @@ def forward( output_router_logits: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - **kwargs, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -720,14 +361,16 @@ def forward( hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, + position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + **kwargs, ) hidden_states = residual + hidden_states @@ -742,15 +385,77 @@ def forward( if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) - if output_router_logits: outputs += (router_logits,) return outputs +class MixtralRotaryEmbedding(nn.Module): + def __init__( + self, + config: MixtralConfig, + device=None, + ): + super().__init__() + self.rope_kwargs = {} + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + 
self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + MIXTRAL_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads @@ -772,17 +477,17 @@ def forward( "The bare Mixtral Model outputting raw hidden-states without any specific head on top.", MIXTRAL_START_DOCSTRING, ) -# copied from transformers.models.qwen2.modeling_qwen2.Qwen2PreTrainedModel with Qwen2->Mixtral -# TODO (Raushan): bring back copied after compile compatibility class MixtralPreTrainedModel(PreTrainedModel): config_class = MixtralConfig base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["MixtralDecoderLayer"] - _skip_keys_device_placement = "past_key_values" + _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True _supports_sdpa = True _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True def _init_weights(self, module): std = self.config.initializer_range @@ -817,7 +522,7 @@ def _init_weights(self, module): Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. - If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see `past_key_values`). If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] @@ -831,17 +536,24 @@ def _init_weights(self, module): config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape - `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. - - Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. 
- - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance, see our + [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache); + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert `input_ids` indices into associated vectors than the @@ -855,9 +567,6 @@ def _init_weights(self, module): output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. - output_router_logits (`bool`, *optional*): - Whether or not to return the logits of all the routers. They are useful for computing the router loss, and - should not be returned during inference. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): @@ -871,8 +580,6 @@ def _init_weights(self, module): "The bare Mixtral Model outputting raw hidden-states without any specific head on top.", MIXTRAL_START_DOCSTRING, ) -# copied from transformers.models.mistral.modeling_mistral.MistralModel with MISTRAL->MIXTRAL,Mistral->Mixtral -# TODO @longjie no longer copied from Mistral after static cache class MixtralModel(MixtralPreTrainedModel): """ Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`MixtralDecoderLayer`] @@ -890,10 +597,10 @@ def __init__(self, config: MixtralConfig): self.layers = nn.ModuleList( [MixtralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) - self._attn_implementation = config._attn_implementation self.norm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - + self.rotary_emb = MixtralRotaryEmbedding(config=config) self.gradient_checkpointing = False + # Initialize weights and apply final processing self.post_init() @@ -903,7 +610,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.embed_tokens = value - # Ignore copy @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING) def forward( self, @@ -918,7 +624,8 @@ def forward( output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - ) -> Union[Tuple, MoeModelOutputWithPast]: + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_router_logits = ( output_router_logits if output_router_logits is not None else self.config.output_router_logits @@ -940,19 +647,8 @@ def forward( ) use_cache = False - # kept for BC (non `Cache` `past_key_values` inputs) - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True - if past_key_values is None: - past_key_values = DynamicCache() - else: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " - "will be removed in v4.47. 
Please convert your cache or use an appropriate `Cache` class " - "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" - ) + if use_cache and past_key_values is None: + past_key_values = DynamicCache() if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) @@ -971,11 +667,13 @@ def forward( hidden_states = inputs_embeds + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_router_logits = () if output_router_logits else None - next_decoder_cache = None for decoder_layer in self.layers: if output_hidden_states: @@ -992,6 +690,7 @@ def forward( output_router_logits, use_cache, cache_position, + position_embeddings, ) else: layer_outputs = decoder_layer( @@ -1003,13 +702,12 @@ def forward( output_router_logits=output_router_logits, use_cache=use_cache, cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -1022,25 +720,15 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None - if return_legacy_cache: - next_cache = next_cache.to_legacy_cache() - - if not return_dict: - return tuple( - v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits] - if v is not None - ) - return MoeModelOutputWithPast( + output = MoeModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, router_logits=all_router_logits, ) + return output if return_dict else output.to_tuple() - # Copied from transformers.models.phi3.modeling_phi3.Phi3Model._update_causal_mask def _update_causal_mask( self, attention_mask: torch.Tensor, @@ -1050,6 +738,14 @@ def _update_causal_mask( output_attentions: bool, ): if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and past_key_values is not None: + is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] + if is_padding_right: + raise ValueError( + "You are attempting to perform batched generation with padding_side='right'" + " this may lead to unexpected behaviour for Flash Attention version of Mixtral. Make sure to " + " call `tokenizer.padding_side = 'left'` before tokenizing the input. " + ) if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None @@ -1117,7 +813,6 @@ def _update_causal_mask( return causal_mask @staticmethod - # Copied from transformers.models.mistral.modeling_mistral.MistralModel._prepare_4d_causal_attention_mask_with_cache_position with Mistral->Mixtral def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, @@ -1185,8 +880,94 @@ def _prepare_4d_causal_attention_mask_with_cache_position( return causal_mask +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... 
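The refactored `MixtralModel.forward` above now allocates a `DynamicCache` on demand when `use_cache=True`, computes the rotary `position_embeddings` once and shares them across all decoder layers, and returns a `MoeModelOutputWithPast` whose `past_key_values` is the cache object itself. A minimal usage sketch follows, assuming a `transformers` build that already contains this refactor; the tiny hyperparameter values are illustrative only and are not part of this patch:

```python
import torch
from transformers import MixtralConfig, MixtralModel

# Arbitrary small config so the randomly initialized model runs quickly on CPU.
config = MixtralConfig(
    vocab_size=128,
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=8,
    num_key_value_heads=4,
    num_local_experts=4,
    num_experts_per_tok=2,
)
model = MixtralModel(config).eval()

input_ids = torch.randint(0, config.vocab_size, (1, 6))
with torch.no_grad():
    # No cache is passed: with use_cache=True the model creates a DynamicCache internally
    # and returns it in `past_key_values`.
    out = model(input_ids, use_cache=True, output_router_logits=True)

print(type(out.past_key_values).__name__)  # DynamicCache
print(out.last_hidden_state.shape)         # torch.Size([1, 6, 64])
print(len(out.router_logits))              # one router-logits tensor per decoder layer
```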
+ + +def load_balancing_loss_func( + gate_logits: Union[torch.Tensor, Tuple[torch.Tensor], None], + num_experts: Optional[int] = None, + top_k=2, + attention_mask: Optional[torch.Tensor] = None, +) -> Union[torch.Tensor, int]: + r""" + Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. + + See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss + function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between + experts is too unbalanced. + + Args: + gate_logits: + Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of + shape [batch_size X sequence_length, num_experts]. + num_experts: + Number of experts + top_k: + The number of experts to route per-token, can be also interpreted as the `top-k` routing + parameter. + attention_mask (`torch.Tensor`, *optional*): + The attention_mask used in forward function + shape [batch_size X sequence_length] if not None. + + Returns: + The auxiliary loss. + """ + if gate_logits is None or not isinstance(gate_logits, tuple): + return 0 + + if isinstance(gate_logits, tuple): + compute_device = gate_logits[0].device + concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) + + routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1) + + _, selected_experts = torch.topk(routing_weights, top_k, dim=-1) + + expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts) + + if attention_mask is None: + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.mean(expert_mask.float(), dim=0) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.mean(routing_weights, dim=0) + else: + batch_size, sequence_length = attention_mask.shape + num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length) + + # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask + expert_attention_mask = ( + attention_mask[None, :, :, None, None] + .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts)) + .reshape(-1, top_k, num_experts) + .to(compute_device) + ) + + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum( + expert_attention_mask, dim=0 + ) + + # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert + router_per_expert_attention_mask = ( + attention_mask[None, :, :, None] + .expand((num_hidden_layers, batch_size, sequence_length, num_experts)) + .reshape(-1, num_experts) + .to(compute_device) + ) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum( + router_per_expert_attention_mask, dim=0 + ) + + overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) + return overall_loss * num_experts + + class MixtralForCausalLM(MixtralPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} def __init__(self, config): super().__init__(config) @@ -1196,6 +977,7 @@ def __init__(self, config): self.router_aux_loss_coef = config.router_aux_loss_coef self.num_experts = config.num_local_experts self.num_experts_per_tok = config.num_experts_per_tok 
+ # Initialize weights and apply final processing self.post_init() @@ -1218,8 +1000,7 @@ def get_decoder(self): return self.model @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) - # Ignore copy + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( self, input_ids: torch.LongTensor = None, @@ -1235,8 +1016,8 @@ def forward( return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, num_logits_to_keep: int = 0, - **loss_kwargs, - ) -> Union[Tuple, MoeCausalLMOutputWithPast]: + **kwargs: Unpack[KwargsForCausalLM], + ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -1291,6 +1072,7 @@ def forward( output_router_logits=output_router_logits, return_dict=return_dict, cache_position=cache_position, + **kwargs, ) hidden_states = outputs[0] @@ -1299,7 +1081,7 @@ def forward( loss = None if labels is not None: - loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs) + loss = self.loss_function(logits, labels, self.vocab_size, **kwargs) aux_loss = None if output_router_logits: @@ -1344,7 +1126,6 @@ def forward( """, MIXTRAL_START_DOCSTRING, ) -# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Mixtral, LLAMA->MIXTRAL class MixtralForSequenceClassification(MixtralPreTrainedModel): def __init__(self, config): super().__init__(config) @@ -1441,7 +1222,6 @@ def forward( """, MIXTRAL_START_DOCSTRING, ) -# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Mixtral, LLAMA->MIXTRAL class MixtralForTokenClassification(MixtralPreTrainedModel): def __init__(self, config): super().__init__(config) @@ -1530,15 +1310,13 @@ def forward( """, MIXTRAL_START_DOCSTRING, ) -# Copied from transformers.models.mistral.modeling_mistral.MistralForQuestionAnswering with Mistral->Mixtral, MISTRAL->MIXTRAL class MixtralForQuestionAnswering(MixtralPreTrainedModel): base_model_prefix = "model" - # Copied from models.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Mixtral def __init__(self, config): super().__init__(config) - self.model = MixtralModel(config) self.qa_outputs = nn.Linear(config.hidden_size, 2) + self.model = MixtralModel(config) # diff with Llama: transformer->model # Initialize weights and apply final processing self.post_init() diff --git a/src/transformers/models/mixtral/modular_mixtral.py b/src/transformers/models/mixtral/modular_mixtral.py new file mode 100644 index 00000000000000..a6069f69b33421 --- /dev/null +++ b/src/transformers/models/mixtral/modular_mixtral.py @@ -0,0 +1,574 @@ +# coding=utf-8 +# Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Mixtral model.""" + +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn + +from ...activations import ACT2FN +from ...cache_utils import DynamicCache +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_outputs import ( + MoeCausalLMOutputWithPast, + MoeModelOutputWithPast, +) +from ...processing_utils import Unpack +from ...utils import ( + LossKwargs, + logging, +) +from ..mistral.modeling_mistral import ( + MistralAttention, + MistralForCausalLM, + MistralForQuestionAnswering, + MistralForSequenceClassification, + MistralForTokenClassification, + MistralModel, + MistralRMSNorm, +) +from .configuration_mixtral import MixtralConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "mistralai/Mixtral-8x7B-v0.1" +_CONFIG_FOR_DOC = "MixtralConfig" + + +def load_balancing_loss_func( + gate_logits: Union[torch.Tensor, Tuple[torch.Tensor], None], + num_experts: Optional[int] = None, + top_k=2, + attention_mask: Optional[torch.Tensor] = None, +) -> Union[torch.Tensor, int]: + r""" + Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. + + See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss + function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between + experts is too unbalanced. + + Args: + gate_logits: + Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of + shape [batch_size X sequence_length, num_experts]. + num_experts: + Number of experts + top_k: + The number of experts to route per-token, can be also interpreted as the `top-k` routing + parameter. + attention_mask (`torch.Tensor`, *optional*): + The attention_mask used in forward function + shape [batch_size X sequence_length] if not None. + + Returns: + The auxiliary loss. 
+ """ + if gate_logits is None or not isinstance(gate_logits, tuple): + return 0 + + if isinstance(gate_logits, tuple): + compute_device = gate_logits[0].device + concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) + + routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1) + + _, selected_experts = torch.topk(routing_weights, top_k, dim=-1) + + expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts) + + if attention_mask is None: + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.mean(expert_mask.float(), dim=0) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.mean(routing_weights, dim=0) + else: + batch_size, sequence_length = attention_mask.shape + num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length) + + # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask + expert_attention_mask = ( + attention_mask[None, :, :, None, None] + .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts)) + .reshape(-1, top_k, num_experts) + .to(compute_device) + ) + + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum( + expert_attention_mask, dim=0 + ) + + # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert + router_per_expert_attention_mask = ( + attention_mask[None, :, :, None] + .expand((num_hidden_layers, batch_size, sequence_length, num_experts)) + .reshape(-1, num_experts) + .to(compute_device) + ) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum( + router_per_expert_attention_mask, dim=0 + ) + + overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) + return overall_loss * num_experts + + +class MixtralBlockSparseTop2MLP(nn.Module): + def __init__(self, config: MixtralConfig): + super().__init__() + self.ffn_dim = config.intermediate_size + self.hidden_dim = config.hidden_size + + self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) + self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False) + self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) + + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_states): + current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states) + current_hidden_states = self.w2(current_hidden_states) + return current_hidden_states + + +class MixtralSparseMoeBlock(nn.Module): + """ + This implementation is + strictly equivalent to standard MoE with full capacity (no + dropped tokens). It's faster since it formulates MoE operations + in terms of block-sparse operations to accommodate imbalanced + assignments of tokens to experts, whereas standard MoE either + (1) drop tokens at the cost of reduced performance or (2) set + capacity factor to number of experts and thus waste computation + and memory on padding. 
+    """
+
+    def __init__(self, config):
+        super().__init__()
+        self.hidden_dim = config.hidden_size
+        self.ffn_dim = config.intermediate_size
+        self.num_experts = config.num_local_experts
+        self.top_k = config.num_experts_per_tok
+
+        # gating
+        self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
+
+        self.experts = nn.ModuleList([MixtralBlockSparseTop2MLP(config) for _ in range(self.num_experts)])
+
+        # Jitter parameters
+        self.jitter_noise = config.router_jitter_noise
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        """ """
+        batch_size, sequence_length, hidden_dim = hidden_states.shape
+        if self.training and self.jitter_noise > 0:
+            hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)
+        hidden_states = hidden_states.view(-1, hidden_dim)
+        # router_logits: (batch * sequence_length, n_experts)
+        router_logits = self.gate(hidden_states)
+
+        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
+        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+        # we cast back to the input dtype
+        routing_weights = routing_weights.to(hidden_states.dtype)
+
+        final_hidden_states = torch.zeros(
+            (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
+        )
+
+        # One hot encode the selected experts to create an expert mask
+        # this will be used to easily index which expert is going to be solicited
+        expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
+
+        # Loop over all available experts in the model and perform the computation on each expert
+        for expert_idx in range(self.num_experts):
+            expert_layer = self.experts[expert_idx]
+            idx, top_x = torch.where(expert_mask[expert_idx])
+
+            # Index the correct hidden states and compute the expert hidden state for
+            # the current expert. We need to make sure to multiply the output hidden
+            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
+            current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
+            current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
+
+            # However `index_add_` only supports torch tensors for indexing so we'll use
+            # the `top_x` tensor here.
+ final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) + final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) + return final_hidden_states, router_logits + + +class MixtralRMSNorm(MistralRMSNorm): + pass + + +class MixtralAttention(MistralAttention): + pass + + +class MixtralDecoderLayer(nn.Module): + def __init__(self, config: MixtralConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = MixtralAttention(config, layer_idx) + + self.block_sparse_moe = MixtralSparseMoeBlock(config) + self.input_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + output_router_logits: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_router_logits (`bool`, *optional*): + Whether or not to return the logits of all the routers. They are useful for computing the router loss, and + should not be returned during inference. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. 
+ kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states, router_logits = self.block_sparse_moe(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if output_router_logits: + outputs += (router_logits,) + + return outputs + + +class MixtralModel(MistralModel): + def __init__(self, config: MixtralConfig): + super().__init__(config) + self.layers = nn.ModuleList( + [MixtralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> Union[Tuple, MoeModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_router_logits = () if output_router_logits else None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + output_router_logits, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + output_router_logits=output_router_logits, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if output_router_logits: + all_router_logits += (layer_outputs[-1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + output = MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_self_attns, + router_logits=all_router_logits, + ) + return output if return_dict else output.to_tuple() + + +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... 
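In the `MixtralForCausalLM` definition that follows, the router auxiliary loss is the `load_balancing_loss_func` defined earlier in this file, fed with the per-layer `router_logits` tuple collected by `MixtralModel`. A small, hedged sketch of calling it directly; the shapes follow the docstring above, and the import path is an assumption based on the generated `modeling_mixtral.py` shown earlier in this patch:

```python
import torch
from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func

batch_size, seq_len, num_layers = 2, 5, 3
num_experts, top_k = 4, 2

# One router-logits tensor per decoder layer, each of shape [batch_size * seq_len, num_experts],
# i.e. the same structure as `MoeModelOutputWithPast.router_logits`.
router_logits = tuple(torch.randn(batch_size * seq_len, num_experts) for _ in range(num_layers))
attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)  # no padding in this toy example

aux_loss = load_balancing_loss_func(
    router_logits, num_experts=num_experts, top_k=top_k, attention_mask=attention_mask
)
# Roughly balanced routing yields a value near `top_k`; imbalanced routing pushes it higher.
print(aux_loss)
```

During training, `MixtralForCausalLM.forward` then adds `config.router_aux_loss_coef * aux_loss` on top of the language-modeling loss, as in the forward pass below.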
+ + +class MixtralForCausalLM(MistralForCausalLM): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = MixtralModel(config) + self.router_aux_loss_coef = config.router_aux_loss_coef + self.num_experts = config.num_local_experts + self.num_experts_per_tok = config.num_experts_per_tok + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, + **kwargs: Unpack[KwargsForCausalLM], + ) -> Union[Tuple, MoeCausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, MixtralForCausalLM + + >>> model = MixtralForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1") + >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
+ ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_router_logits=output_router_logits, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.vocab_size, **kwargs) + + aux_loss = None + if output_router_logits: + aux_loss = load_balancing_loss_func( + outputs.router_logits if return_dict else outputs[-1], + self.num_experts, + self.num_experts_per_tok, + attention_mask, + ) + if labels is not None: + loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device + + if not return_dict: + output = (logits,) + outputs[1:] + if output_router_logits: + output = (aux_loss,) + output + return (loss,) + output if loss is not None else output + + return MoeCausalLMOutputWithPast( + loss=loss, + aux_loss=aux_loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + router_logits=outputs.router_logits, + ) + + +class MixtralForSequenceClassification(MistralForSequenceClassification): + pass + + +class MixtralForTokenClassification(MistralForTokenClassification): + pass + + +class MixtralForQuestionAnswering(MistralForQuestionAnswering): + pass diff --git a/src/transformers/models/mllama/modeling_mllama.py b/src/transformers/models/mllama/modeling_mllama.py index d53a80dd892901..3e0c4d7a5123a7 100644 --- a/src/transformers/models/mllama/modeling_mllama.py +++ b/src/transformers/models/mllama/modeling_mllama.py @@ -829,7 +829,8 @@ def __init__(self, config): self.act_fn = ACT2FN[config.hidden_act] def forward(self, x): - return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj # Modified from transformers.models.llama.modeling_llama.LlamaDecoderLayer diff --git a/src/transformers/models/moshi/modeling_moshi.py b/src/transformers/models/moshi/modeling_moshi.py index 82abfa66c2e837..f0281f57cf1c75 100644 --- a/src/transformers/models/moshi/modeling_moshi.py +++ b/src/transformers/models/moshi/modeling_moshi.py @@ -36,6 +36,7 @@ ModelOutput, Seq2SeqLMOutput, ) +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import PreTrainedModel from ...pytorch_utils import ALL_LAYERNORM_LAYERS from ...utils import ( @@ -307,24 +308,55 @@ def forward(self, x, layer_idx=None): # Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Moshi class 
MoshiRotaryEmbedding(nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + def __init__( + self, + config: MoshiConfig, + device=None, + ): super().__init__() + self.rope_kwargs = {} + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len @torch.no_grad() - # copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.forward - # TODO(joao): add me back asap :) def forward(self, x, position_ids): - # x: [bs, num_attention_heads, seq_len, head_size] + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() - # Force float32 since bfloat16 loses precision on long contexts - # See https://github.com/huggingface/transformers/pull/29285 + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): @@ -332,6 +364,11 @@ def forward(self, x, position_ids): emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() + + # Advanced RoPE types (e.g. 
yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) @@ -456,13 +493,10 @@ def __init__(self, config: MoshiConfig, layer_idx: Optional[int] = None, use_fle self.rotary_emb = None if use_rope: self.rope_theta = config.rope_theta - self.rotary_emb = MoshiRotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - base=self.rope_theta, - ) + self.rotary_emb = MoshiRotaryEmbedding(config) - # Copied from transformers.models.gemma.modeling_gemma.GemmaAttention.forward + # copied from transformers.models.gemma.modeling_gemma.GemmaAttention.forward + # no longer copied after attention refactors def forward( self, hidden_states: torch.Tensor, @@ -527,7 +561,8 @@ def forward( return attn_output, attn_weights, past_key_value -# Copied from transformers.models.gemma.modeling_gemma.GemmaFlashAttention2 with Gemma->Moshi +# NO LONGER EXIST Copied from transformers.models.gemma.modeling_gemma.GemmaFlashAttention2 with Gemma->Moshi +# TODO cyril: modular class MoshiFlashAttention2(MoshiAttention): """ Moshi flash attention module. This module inherits from `MoshiAttention` as the weights of the module stays @@ -643,7 +678,8 @@ def forward( return attn_output, attn_weights, past_key_value -# Copied from transformers.models.gemma.modeling_gemma.GemmaSdpaAttention with Gemma->Moshi +# NO LONGER EXIST Copied from transformers.models.gemma.modeling_gemma.GemmaSdpaAttention with Gemma->Moshi +# TODO cyril: modular class MoshiSdpaAttention(MoshiAttention): """ Moshi attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py index 109ddfb626d26b..f83bccb7e4f6f3 100644 --- a/src/transformers/models/musicgen/modeling_musicgen.py +++ b/src/transformers/models/musicgen/modeling_musicgen.py @@ -324,7 +324,6 @@ class MusicgenFlashAttention2(MusicgenAttention): flash attention and deal with padding tokens in case the input contains any of them. """ - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py index 61f2ce414e1ddf..dc0e9b882b20cf 100644 --- a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py +++ b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py @@ -340,7 +340,6 @@ class MusicgenMelodyFlashAttention2(MusicgenMelodyAttention): flash attention and deal with padding tokens in case the input contains any of them. 
""" - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/src/transformers/models/nemotron/modeling_nemotron.py b/src/transformers/models/nemotron/modeling_nemotron.py index 78dace1a53ce55..a0a10bdc6f3550 100644 --- a/src/transformers/models/nemotron/modeling_nemotron.py +++ b/src/transformers/models/nemotron/modeling_nemotron.py @@ -301,7 +301,8 @@ def forward( return attn_output, attn_weights, past_key_value -# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with LLAMA->NEMOTRON,Llama->Nemotron,llama->nemotron +# NO LONGER EXIST Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with LLAMA->NEMOTRON,Llama->Nemotron,llama->nemotron +# TODO cyril: modular class NemotronFlashAttention2(NemotronAttention): """ Nemotron flash attention module. This module inherits from `NemotronAttention` as the weights of the module stays @@ -415,7 +416,8 @@ def forward( return attn_output, attn_weights, past_key_value -# Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with LLAMA->NEMOTRON,Llama->Nemotron,llama->nemotron +# NO LONGER EXIST Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with LLAMA->NEMOTRON,Llama->Nemotron,llama->nemotron +# TODO cyril: modular class NemotronSdpaAttention(NemotronAttention): """ Nemotron attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from @@ -514,7 +516,8 @@ def forward( } -# Copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with LLAMA->NEMOTRON,Llama->Nemotron,llama->nemotron +# copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with LLAMA->NEMOTRON,Llama->Nemotron,llama->nemotron +# no longer copied after attention refactors class NemotronDecoderLayer(nn.Module): # Ignore copy def __init__(self, config: NemotronConfig, layer_idx: int): diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 8b40c41e34dcd3..11d3d99f4f72c9 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -1,59 +1,35 @@ -# coding=utf-8 -# Copyright 2024 EleutherAI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""PyTorch OLMo model.""" - -import math -from typing import List, Optional, Tuple, Union +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/olmo/modular_olmo.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. 
If any change should be done, please apply the change to the +# modular_olmo.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +from typing import Callable, List, Optional, Tuple, Union import torch +import torch.nn as nn import torch.nn.functional as F -import torch.utils.checkpoint -from torch import nn from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter -from ...modeling_outputs import ( - BaseModelOutputWithPast, - CausalLMOutputWithPast, -) -from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import ALL_LAYERNORM_LAYERS +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack from ...utils import ( + LossKwargs, add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, ) from .configuration_olmo import OlmoConfig -if is_flash_attn_2_available(): - from ...modeling_flash_attention_utils import _flash_attention_forward - - logger = logging.get_logger(__name__) - _CONFIG_FOR_DOC = "OlmoConfig" @@ -71,70 +47,22 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: ) -ALL_LAYERNORM_LAYERS.append(OlmoLayerNorm) - - -# copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Olmo -# TODO(joao): add me back asap :) -class OlmoRotaryEmbedding(nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): +class OlmoMLP(nn.Module): + def __init__(self, config): super().__init__() - self.scaling_factor = scaling_factor - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - # For BC we register cos and sin cached - self.max_seq_len_cached = max_position_embeddings - - @torch.no_grad() - def forward(self, x, position_ids): - # x: [bs, num_attention_heads, seq_len, head_size] - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) - position_ids_expanded = position_ids[:, None, :].float() - # Force float32 since bfloat16 loses precision on long contexts - # See https://github.com/huggingface/transformers/pull/29285 - device_type = x.device.type - device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() - sin = emb.sin() - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) - - -class OlmoLinearScalingRotaryEmbedding(OlmoRotaryEmbedding): - """OlmoRotaryEmbedding extended with linear scaling. 
Credits to the Reddit user /u/kaiokendev""" - - def forward(self, x, position_ids): - # difference to the original RoPE: a scaling factor is aplied to the position ids - position_ids = position_ids.float() / self.scaling_factor - cos, sin = super().forward(x, position_ids) - return cos, sin - - -class OlmoDynamicNTKScalingRotaryEmbedding(OlmoRotaryEmbedding): - """OlmoRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" - - def forward(self, x, position_ids): - # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length - seq_len = torch.max(position_ids) + 1 - if seq_len > self.max_position_embeddings: - base = self.base * ( - (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) - ) ** (self.dim / (self.dim - 2)) - inv_freq = 1.0 / ( - base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim) - ) - self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: this may break with compilation + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] - cos, sin = super().forward(x, position_ids) - return cos, sin + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj -# Copied from transformers.models.llama.modeling_llama.rotate_half def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] @@ -142,7 +70,6 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -170,22 +97,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): return q_embed, k_embed -class OlmoMLP(nn.Module): - def __init__(self, config): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, x): - return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - - -# Copied from transformers.models.llama.modeling_llama.repeat_kv def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, @@ -198,167 +109,69 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + class OlmoAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" - # copied from transformers.models.llama.modeling_llama.LlamaAttention.__init__ with Llama->Olmo - # TODO(joao): add me back asap :) - def __init__(self, config: OlmoConfig, layer_idx: Optional[int] = None): + def __init__(self, config: OlmoConfig, layer_idx: int): super().__init__() self.config = config self.layer_idx = layer_idx - if layer_idx is None: - logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " - "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class." - ) - + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 self.attention_dropout = config.attention_dropout - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.max_position_embeddings = config.max_position_embeddings - self.rope_theta = config.rope_theta self.is_causal = True - if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." 
- ) - - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias) - self._init_rope() - - def _init_rope(self): - if self.config.rope_scaling is None: - self.rotary_emb = OlmoRotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - base=self.rope_theta, - ) - else: - scaling_type = self.config.rope_scaling["type"] - scaling_factor = self.config.rope_scaling["factor"] - if scaling_type == "linear": - self.rotary_emb = OlmoLinearScalingRotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - scaling_factor=scaling_factor, - base=self.rope_theta, - ) - elif scaling_type == "dynamic": - self.rotary_emb = OlmoDynamicNTKScalingRotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - scaling_factor=scaling_factor, - base=self.rope_theta, - ) - else: - raise ValueError(f"Unknown RoPE scaling type {scaling_type}") - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - if self.config.clip_qkv is not None: - query_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) - key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) - value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - cos, sin = self.rotary_emb(value_states, position_ids) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = 
torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class OlmoFlashAttention2(OlmoAttention): - """ - OLMo flash attention module. This module inherits from `OlmoAttention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - output_attentions = False - - bsz, q_len, _ = hidden_states.size() + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) @@ -369,14 +182,11 @@ def forward( key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) - # Flash attention requires the input to have the shape - # batch_size x seq_length x head_dim x hidden_dim - # therefore we just need to keep the original shape - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + query_states = 
query_states.view(hidden_shape).transpose(1, 2) + key_states = key_states.view(hidden_shape).transpose(1, 2) + value_states = value_states.view(hidden_shape).transpose(1, 2) - cos, sin = self.rotary_emb(value_states, position_ids) + cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: @@ -384,174 +194,42 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache - # to be able to avoid many of these transpose/reshape/view. - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - dropout_rate = self.attention_dropout if self.training else 0.0 - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. (OlmoRMSNorm handles it correctly) - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) else: - target_dtype = self.q_proj.weight.dtype + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - attn_output = _flash_attention_forward( + attn_output, attn_weights = attention_interface( + self, query_states, key_states, value_states, attention_mask, - q_len, - position_ids=position_ids, - dropout=dropout_rate, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - is_causal=self.is_causal, - ) - - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class OlmoSdpaAttention(OlmoAttention): - """ - OLMo attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from - `OlmoAttention` as the weights of the module stays untouched. 
The only changes are on the forward pass to adapt to - SDPA API. - """ - - # Adapted from OlmoAttention.forward - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "OlmoModel is using OlmoSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - ) - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - if self.config.clip_qkv is not None: - query_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) - key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) - value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - cos, sin = self.rotary_emb(value_states, position_ids) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - causal_mask = attention_mask - # if attention_mask is not None and cache_position is not None: - if attention_mask is not None: - causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and causal_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. 
- is_causal = True if causal_mask is None and q_len > 1 else False - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=causal_mask, - dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=is_causal, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, ) - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(bsz, q_len, self.hidden_size) - + attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) - - return attn_output, None, past_key_value - - -OLMO_ATTENTION_CLASSES = { - "eager": OlmoAttention, - "flash_attention_2": OlmoFlashAttention2, - "sdpa": OlmoSdpaAttention, -} + return attn_output, attn_weights class OlmoDecoderLayer(nn.Module): def __init__(self, config: OlmoConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size - - self.self_attn = OLMO_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + self.self_attn = OlmoAttention(config=config, layer_idx=layer_idx) self.mlp = OlmoMLP(config) self.input_layernorm = OlmoLayerNorm(config.hidden_size) self.post_attention_layernorm = OlmoLayerNorm(config.hidden_size) - # copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward - # TODO(joao): add me back asap :) def forward( self, hidden_states: torch.Tensor, @@ -561,33 +239,15 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - **kwargs, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): - attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, - query_sequence_length, key_sequence_length)` if default attention is used. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). 
- past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence - kwargs (`dict`, *optional*): - Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code - into the model - """ residual = hidden_states hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, @@ -595,6 +255,7 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + position_embeddings=position_embeddings, **kwargs, ) hidden_states = residual + hidden_states @@ -606,16 +267,77 @@ def forward( hidden_states = residual + hidden_states outputs = (hidden_states,) - if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) - return outputs +class OlmoRotaryEmbedding(nn.Module): + def __init__( + self, + config: OlmoConfig, + device=None, + ): + super().__init__() + self.rope_kwargs = {} + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = 
torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + OLMO_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads @@ -637,7 +359,6 @@ def forward( "The bare Olmo Model outputting raw hidden-states without any specific head on top.", OLMO_START_DOCSTRING, ) -# Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel with Llama->Olmo class OlmoPreTrainedModel(PreTrainedModel): config_class = OlmoConfig base_model_prefix = "model" @@ -759,6 +480,7 @@ def __init__(self, config: OlmoConfig): [OlmoDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) self.norm = OlmoLayerNorm(config.hidden_size) + self.rotary_emb = OlmoRotaryEmbedding(config=config) self.gradient_checkpointing = False # Initialize weights and apply final processing @@ -771,20 +493,19 @@ def set_input_embeddings(self, value): self.embed_tokens = value @add_start_docstrings_to_model_forward(OLMO_INPUTS_DOCSTRING) - # copied from transformers.models.llama.modeling_llama.LlamaModel.forward - # TODO(joao): add me back asap :) def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -805,25 +526,15 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - # kept for BC (non `Cache` `past_key_values` inputs) - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True - if past_key_values is None: - past_key_values = DynamicCache() - else: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " - "will be removed in v4.47. 
Please convert your cache or use an appropriate `Cache` class " - "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" - ) + if use_cache and past_key_values is None: + past_key_values = DynamicCache() if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) + if position_ids is None: position_ids = cache_position.unsqueeze(0) @@ -831,15 +542,16 @@ def forward( attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions ) - # embed positions hidden_states = inputs_embeds + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_decoder_cache = None - for decoder_layer in self.layers: + for decoder_layer in self.layers[: self.config.num_hidden_layers]: if output_hidden_states: all_hidden_states += (hidden_states,) @@ -853,6 +565,7 @@ def forward( output_attentions, use_cache, cache_position, + position_embeddings, ) else: layer_outputs = decoder_layer( @@ -863,13 +576,12 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -879,20 +591,14 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None - if return_legacy_cache: - next_cache = next_cache.to_legacy_cache() - - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( + output = BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, ) + return output if return_dict else output.to_tuple() - # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask def _update_causal_mask( self, attention_mask: torch.Tensor, @@ -959,7 +665,6 @@ def _update_causal_mask( return causal_mask @staticmethod - # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, @@ -1016,9 +721,12 @@ def _prepare_4d_causal_attention_mask_with_cache_position( return causal_mask -# TODO: re-enable check: Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->OLMO,Llama->Olmo +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... 
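
`KwargsForCausalLM` above simply merges `FlashAttentionKwargs` and `LossKwargs` so the causal-LM `forward` can type its `**kwargs` via `Unpack`. The subtler detail shared by the Mixtral head earlier and the OLMo head below is the `num_logits_to_keep` slicing, `hidden_states[:, -num_logits_to_keep:, :]`: because `-0 == 0`, a value of `0` keeps every position, while a positive value keeps only the trailing positions. A toy illustration (made-up shapes and a hypothetical helper name, shown only to demonstrate the slicing behaviour):

```python
# Toy illustration of the `num_logits_to_keep` slicing used by the causal-LM heads.
import torch

hidden_states = torch.randn(2, 10, 64)  # (batch, seq_len, hidden_size)

def slice_for_logits(hidden_states: torch.Tensor, num_logits_to_keep: int) -> torch.Tensor:
    # -0 == 0, so `[-0:]` is a full slice: num_logits_to_keep=0 keeps all positions.
    return hidden_states[:, -num_logits_to_keep:, :]

print(slice_for_logits(hidden_states, 0).shape)  # torch.Size([2, 10, 64]) -> logits for every token
print(slice_for_logits(hidden_states, 1).shape)  # torch.Size([2, 1, 64])  -> last token only, enough for generation
```
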
+ + class OlmoForCausalLM(OlmoPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} def __init__(self, config): super().__init__(config) @@ -1049,13 +757,12 @@ def get_decoder(self): @add_start_docstrings_to_model_forward(OLMO_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) - # Ignore copy def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, @@ -1064,7 +771,7 @@ def forward( return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, num_logits_to_keep: int = 0, - **loss_kwargs, + **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -1085,8 +792,8 @@ def forward( ```python >>> from transformers import AutoTokenizer, OlmoForCausalLM - >>> model = OlmoForCausalLM.from_pretrained("allenai/OLMo-1B-hf") - >>> tokenizer = AutoTokenizer.from_pretrained("allenai/OLMo-1B-hf") + >>> model = OlmoForCausalLM.from_pretrained("meta-olmo/Olmo-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-olmo/Olmo-2-7b-hf") >>> prompt = "Hey, are you conscious? Can you talk to me?" >>> inputs = tokenizer(prompt, return_tensors="pt") @@ -1094,9 +801,8 @@ def forward( >>> # Generate >>> generate_ids = model.generate(inputs.input_ids, max_length=30) >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - 'Hey, are you conscious? Can you talk to me?\nI’m not sure if you’re conscious of this, but I’m' - ``` - """ + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
+ ```""" output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -1115,6 +821,7 @@ def forward( output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, + **kwargs, ) hidden_states = outputs[0] @@ -1123,7 +830,7 @@ def forward( loss = None if labels is not None: - loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs) + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) if not return_dict: output = (logits,) + outputs[1:] diff --git a/src/transformers/models/olmo/modular_olmo.py b/src/transformers/models/olmo/modular_olmo.py new file mode 100644 index 00000000000000..2a43e6f9c75d05 --- /dev/null +++ b/src/transformers/models/olmo/modular_olmo.py @@ -0,0 +1,126 @@ +from typing import Callable, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint + +from ...cache_utils import Cache +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS +from ...utils import logging +from ..llama.modeling_llama import ( + LlamaAttention, + LlamaDecoderLayer, + LlamaForCausalLM, + LlamaMLP, + LlamaModel, + apply_rotary_pos_emb, + eager_attention_forward, +) +from .configuration_olmo import OlmoConfig + + +logger = logging.get_logger(__name__) + + +class OlmoLayerNorm(nn.Module): + """LayerNorm but with no learnable weight or bias.""" + + def __init__(self, hidden_size: int) -> None: + super().__init__() + self.normalized_shape = (hidden_size,) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + orig_dtype = hidden_states.dtype + return F.layer_norm(hidden_states.to(dtype=torch.float32), self.normalized_shape, None, None, eps=1e-5).to( + orig_dtype + ) + + +class OlmoMLP(LlamaMLP): + def __init__(self, config): + super().__init__(config) + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + + +class OlmoAttention(LlamaAttention): + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + if self.config.clip_qkv is not None: + query_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) + key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) + value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) + + query_states = query_states.view(hidden_shape).transpose(1, 2) + key_states = key_states.view(hidden_shape).transpose(1, 2) + value_states = value_states.view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for 
the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class OlmoDecoderLayer(LlamaDecoderLayer): + def __init__(self, config: OlmoConfig, layer_idx: int): + super().__init__(config, layer_idx) + self.input_layernorm = OlmoLayerNorm(config.hidden_size) + self.post_attention_layernorm = OlmoLayerNorm(config.hidden_size) + self.self_attn = OlmoAttention(config=config, layer_idx=layer_idx) + + +class OlmoModel(LlamaModel): + def __init__(self, config: OlmoConfig): + super().__init__(config) + self.layers = nn.ModuleList( + [OlmoDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = OlmoLayerNorm(config.hidden_size) + + +class OlmoForCausalLM(LlamaForCausalLM): + pass diff --git a/src/transformers/models/olmo2/configuration_olmo2.py b/src/transformers/models/olmo2/configuration_olmo2.py index 144520f87ed7f9..83c3263de1f552 100644 --- a/src/transformers/models/olmo2/configuration_olmo2.py +++ b/src/transformers/models/olmo2/configuration_olmo2.py @@ -5,6 +5,7 @@ # modular_olmo2.py file directly. One of our CI enforces this. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 + from ...configuration_utils import PretrainedConfig diff --git a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py index 6c35587f1f14fc..49ae798e7f1101 100644 --- a/src/transformers/models/olmo2/modeling_olmo2.py +++ b/src/transformers/models/olmo2/modeling_olmo2.py @@ -4,35 +4,31 @@ # the file from the modular. If any change should be done, please apply the change to the # modular_olmo2.py file directly. One of our CI enforces this. 
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -import math -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch -from torch import nn +import torch.nn as nn from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast -from ...modeling_utils import PreTrainedModel +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack from ...utils import ( + LossKwargs, add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, ) from .configuration_olmo2 import Olmo2Config -if is_flash_attn_2_available(): - from ...modeling_flash_attention_utils import _flash_attention_forward - - logger = logging.get_logger(__name__) - _CONFIG_FOR_DOC = "Olmo2Config" @@ -56,66 +52,6 @@ def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" -# copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Olmo2 -# TODO(joao): add me back asap :) -class Olmo2RotaryEmbedding(nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): - super().__init__() - self.scaling_factor = scaling_factor - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - # For BC we register cos and sin cached - self.max_seq_len_cached = max_position_embeddings - - @torch.no_grad() - def forward(self, x, position_ids): - # x: [bs, num_attention_heads, seq_len, head_size] - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) - position_ids_expanded = position_ids[:, None, :].float() - # Force float32 since bfloat16 loses precision on long contexts - # See https://github.com/huggingface/transformers/pull/29285 - device_type = x.device.type - device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() - sin = emb.sin() - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) - - -class Olmo2LinearScalingRotaryEmbedding(Olmo2RotaryEmbedding): - """Olmo2RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" - - def forward(self, x, position_ids): - # difference to the original RoPE: a scaling factor is aplied to the position ids - position_ids = position_ids.float() / self.scaling_factor - cos, sin = super().forward(x, position_ids) - return cos, sin - - -class Olmo2DynamicNTKScalingRotaryEmbedding(Olmo2RotaryEmbedding): - """Olmo2RotaryEmbedding extended with Dynamic NTK scaling. 
Credits to the Reddit users /u/bloc97 and /u/emozilla""" - - def forward(self, x, position_ids): - # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length - seq_len = torch.max(position_ids) + 1 - if seq_len > self.max_position_embeddings: - base = self.base * ( - (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) - ) ** (self.dim / (self.dim - 2)) - inv_freq = 1.0 / ( - base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim) - ) - self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: this may break with compilation - - cos, sin = super().forward(x, position_ids) - return cos, sin - - def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] @@ -162,180 +98,81 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + class Olmo2Attention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" - # copied from transformers.models.llama.modeling_llama.LlamaAttention.__init__ with Llama->Olmo2 - # TODO(joao): add me back asap :) def __init__(self, config: Olmo2Config, layer_idx: Optional[int] = None): super().__init__() self.config = config self.layer_idx = layer_idx - if layer_idx is None: - logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " - "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class." - ) - + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 self.attention_dropout = config.attention_dropout - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.max_position_embeddings = config.max_position_embeddings - self.rope_theta = config.rope_theta self.is_causal = True - if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." 
- ) - - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias) - self._init_rope() - self.q_norm = Olmo2RMSNorm(self.num_heads * self.head_dim, config.rms_norm_eps) - self.k_norm = Olmo2RMSNorm(self.num_key_value_heads * self.head_dim, config.rms_norm_eps) - - def _init_rope(self): - if self.config.rope_scaling is None: - self.rotary_emb = Olmo2RotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - base=self.rope_theta, - ) - else: - scaling_type = self.config.rope_scaling["type"] - scaling_factor = self.config.rope_scaling["factor"] - if scaling_type == "linear": - self.rotary_emb = Olmo2LinearScalingRotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - scaling_factor=scaling_factor, - base=self.rope_theta, - ) - elif scaling_type == "dynamic": - self.rotary_emb = Olmo2DynamicNTKScalingRotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - scaling_factor=scaling_factor, - base=self.rope_theta, - ) - else: - raise ValueError(f"Unknown RoPE scaling type {scaling_type}") - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_norm(self.q_proj(hidden_states)) - key_states = self.k_norm(self.k_proj(hidden_states)) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - cos, sin = self.rotary_emb(value_states, position_ids) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - if 
attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class Olmo2FlashAttention2(Olmo2Attention): - """ - Olmo2 flash attention module. This module inherits from `Olmo2Attention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - - OLMo2 flash attention module. This module inherits from `Olmo2Attention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + self.q_norm = Olmo2RMSNorm(config.num_attention_heads * self.head_dim, config.rms_norm_eps) + self.k_norm = Olmo2RMSNorm(config.num_key_value_heads * self.head_dim, config.rms_norm_eps) def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - output_attentions = False - - bsz, q_len, _ = hidden_states.size() + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) query_states = self.q_norm(self.q_proj(hidden_states)) key_states = self.k_norm(self.k_proj(hidden_states)) value_states = self.v_proj(hidden_states) - # Flash attention requires the input to have the shape - # batch_size x seq_length x head_dim x hidden_dim - # therefore we just need to keep the original shape - query_states = query_states.view(bsz, q_len, self.num_heads, 
self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + query_states = query_states.view(hidden_shape).transpose(1, 2) + key_states = key_states.view(hidden_shape).transpose(1, 2) + value_states = value_states.view(hidden_shape).transpose(1, 2) - cos, sin = self.rotary_emb(value_states, position_ids) + cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: @@ -343,135 +180,30 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache - # to be able to avoid many of these transpose/reshape/view. - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - dropout_rate = self.attention_dropout if self.training else 0.0 - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. (OlmoRMSNorm handles it correctly) - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." 
- ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - attn_output = _flash_attention_forward( + attn_output, attn_weights = attention_interface( + self, query_states, key_states, value_states, attention_mask, - q_len, - position_ids=position_ids, - dropout=dropout_rate, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - is_causal=self.is_causal, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, ) - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class Olmo2SdpaAttention(Olmo2Attention): - """ - Olmo2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from - `Olmo2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to - SDPA API. - """ - - # Adapted from Olmo2Attention.forward - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "Olmo2Model is using Olmo2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
- ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - ) - bsz, q_len, _ = hidden_states.size() - query_states = self.q_norm(self.q_proj(hidden_states)) - key_states = self.k_norm(self.k_proj(hidden_states)) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - cos, sin = self.rotary_emb(value_states, position_ids) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - causal_mask = attention_mask - # if attention_mask is not None and cache_position is not None: - if attention_mask is not None: - causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and causal_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. 
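The three per-backend attention classes being removed here (eager, flash, SDPA) are folded into a single `Olmo2Attention` that picks a backend callable at call time: `eager_attention_forward` by default, otherwise a lookup in `ALL_ATTENTION_FUNCTIONS` keyed by `config._attn_implementation`. The sketch below illustrates that dispatch pattern in isolation; `ATTENTION_BACKENDS` and both backend functions are simplified stand-ins (no GQA repeat, no causal-mask construction), not the library implementations.

```python
import math
from typing import Callable, Dict, Optional

import torch


def eager_attention_sketch(query, key, value, attention_mask: Optional[torch.Tensor], scaling: float):
    # stripped-down matmul -> softmax -> matmul path, in the spirit of eager_attention_forward
    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask[:, :, :, : key.shape[-2]]
    attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    return torch.matmul(attn_weights, value), attn_weights


def sdpa_attention_sketch(query, key, value, attention_mask: Optional[torch.Tensor], scaling: float):
    # fused PyTorch kernel; it returns no attention weights, and its default scale
    # is 1/sqrt(head_dim), i.e. the same `scaling` the eager path uses here
    out = torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)
    return out, None


# hypothetical registry standing in for ALL_ATTENTION_FUNCTIONS
ATTENTION_BACKENDS: Dict[str, Callable] = {"eager": eager_attention_sketch, "sdpa": sdpa_attention_sketch}

query = key = value = torch.randn(1, 2, 4, 8)  # (batch, heads, seq, head_dim)
for impl in ("eager", "sdpa"):
    attention_interface = ATTENTION_BACKENDS[impl]
    attn_output, attn_weights = attention_interface(query, key, value, None, scaling=1 / math.sqrt(8))
    print(impl, attn_output.shape, None if attn_weights is None else attn_weights.shape)
```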
- is_causal = True if causal_mask is None and q_len > 1 else False - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=causal_mask, - dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=is_causal, - ) - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(bsz, q_len, self.hidden_size) + attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) - return attn_output, None, past_key_value + return attn_output, attn_weights class Olmo2MLP(nn.Module): @@ -486,29 +218,20 @@ def __init__(self, config): self.act_fn = ACT2FN[config.hidden_act] def forward(self, x): - return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - - -OLMO2_ATTENTION_CLASSES = { - "eager": Olmo2Attention, - "flash_attention_2": Olmo2FlashAttention2, - "sdpa": Olmo2SdpaAttention, -} + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj class Olmo2DecoderLayer(nn.Module): def __init__(self, config: Olmo2Config, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size - - self.self_attn = OLMO2_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + self.self_attn = Olmo2Attention(config=config, layer_idx=layer_idx) self.mlp = Olmo2MLP(config) self.post_attention_layernorm = Olmo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_feedforward_layernorm = Olmo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - # copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward - # TODO(joao): add me back asap :) def forward( self, hidden_states: torch.Tensor, @@ -518,31 +241,13 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): - attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, - query_sequence_length, key_sequence_length)` if default attention is used. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). 
- past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence - kwargs (`dict`, *optional*): - Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code - into the model - """ residual = hidden_states # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, @@ -550,6 +255,7 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + position_embeddings=position_embeddings, **kwargs, ) hidden_states = self.post_attention_layernorm(hidden_states) @@ -564,11 +270,75 @@ def forward( outputs = (hidden_states,) if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) + return outputs +class Olmo2RotaryEmbedding(nn.Module): + def __init__( + self, + config: Olmo2Config, + device=None, + ): + super().__init__() + self.rope_kwargs = {} + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # 
Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + OLMO2_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads @@ -711,6 +481,7 @@ def __init__(self, config: Olmo2Config): [Olmo2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) self.norm = Olmo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Olmo2RotaryEmbedding(config=config) self.gradient_checkpointing = False # Initialize weights and apply final processing @@ -723,20 +494,19 @@ def set_input_embeddings(self, value): self.embed_tokens = value @add_start_docstrings_to_model_forward(OLMO2_INPUTS_DOCSTRING) - # copied from transformers.models.llama.modeling_llama.LlamaModel.forward - # TODO(joao): add me back asap :) def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -757,25 +527,15 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - # kept for BC (non `Cache` `past_key_values` inputs) - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True - if past_key_values is None: - past_key_values = DynamicCache() - else: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " - "will be removed in v4.47. 
Please convert your cache or use an appropriate `Cache` class " - "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" - ) + if use_cache and past_key_values is None: + past_key_values = DynamicCache() if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) + if position_ids is None: position_ids = cache_position.unsqueeze(0) @@ -783,15 +543,16 @@ def forward( attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions ) - # embed positions hidden_states = inputs_embeds + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_decoder_cache = None - for decoder_layer in self.layers: + for decoder_layer in self.layers[: self.config.num_hidden_layers]: if output_hidden_states: all_hidden_states += (hidden_states,) @@ -805,6 +566,7 @@ def forward( output_attentions, use_cache, cache_position, + position_embeddings, ) else: layer_outputs = decoder_layer( @@ -815,13 +577,12 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -831,18 +592,13 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None - if return_legacy_cache: - next_cache = next_cache.to_legacy_cache() - - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( + output = BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, ) + return output if return_dict else output.to_tuple() def _update_causal_mask( self, @@ -966,11 +722,14 @@ def _prepare_4d_causal_attention_mask_with_cache_position( return causal_mask -# TODO: re-enable check: Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->OLMO2,Llama->Olmo2 +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... 
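`KwargsForCausalLM` just above merges the flash-attention and loss kwargs into one typed bag, which the causal-LM `forward` below accepts as `**kwargs: Unpack[KwargsForCausalLM]` and forwards unchanged to both the model and the loss. A minimal sketch of that pattern follows; the TypedDict keys and names are illustrative assumptions, not the real transformers definitions.

```python
from typing import Optional

import torch
from typing_extensions import TypedDict, Unpack


class SketchFlashAttentionKwargs(TypedDict, total=False):
    sliding_window: int  # illustrative key a flash-attention backend might read


class SketchLossKwargs(TypedDict, total=False):
    num_items_in_batch: int  # illustrative key a loss function might read


class SketchKwargsForCausalLM(SketchFlashAttentionKwargs, SketchLossKwargs): ...


def causal_lm_loss(
    logits: torch.Tensor,
    labels: Optional[torch.Tensor] = None,
    **kwargs: Unpack[SketchKwargsForCausalLM],
) -> Optional[torch.Tensor]:
    # the same **kwargs dict travels to every consumer; each reads the keys it
    # understands (here only `num_items_in_batch`) and ignores the rest
    if labels is None:
        return None
    loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1), reduction="sum")
    return loss / kwargs.get("num_items_in_batch", labels.numel())


logits, labels = torch.randn(2, 5, 11), torch.randint(0, 11, (2, 5))
print(causal_lm_loss(logits, labels, num_items_in_batch=10))
```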
+ + class Olmo2ForCausalLM(Olmo2PreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} - def __init__(self, config: Olmo2Config): + def __init__(self, config): super().__init__(config) self.model = Olmo2Model(config) self.vocab_size = config.vocab_size @@ -999,13 +758,12 @@ def get_decoder(self): @add_start_docstrings_to_model_forward(OLMO2_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) - # Ignore copy def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, @@ -1014,7 +772,7 @@ def forward( return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, num_logits_to_keep: int = 0, - **loss_kwargs, + **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -1035,8 +793,8 @@ def forward( ```python >>> from transformers import AutoTokenizer, Olmo2ForCausalLM - >>> model = Olmo2ForCausalLM.from_pretrained("allenai/Olmo2-1B-hf") - >>> tokenizer = AutoTokenizer.from_pretrained("allenai/Olmo2-1B-hf") + >>> model = Olmo2ForCausalLM.from_pretrained("meta-olmo2/Olmo2-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-olmo2/Olmo2-2-7b-hf") >>> prompt = "Hey, are you conscious? Can you talk to me?" >>> inputs = tokenizer(prompt, return_tensors="pt") @@ -1044,9 +802,8 @@ def forward( >>> # Generate >>> generate_ids = model.generate(inputs.input_ids, max_length=30) >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - 'Hey, are you conscious? Can you talk to me?\nI’m not sure if you’re conscious of this, but I’m' - ``` - """ + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
+ ```""" output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -1065,6 +822,7 @@ def forward( output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, + **kwargs, ) hidden_states = outputs[0] @@ -1073,7 +831,7 @@ def forward( loss = None if labels is not None: - loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs) + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) if not return_dict: output = (logits,) + outputs[1:] diff --git a/src/transformers/models/olmo2/modular_olmo2.py b/src/transformers/models/olmo2/modular_olmo2.py index 393d17c59c1a8b..5f119170804466 100644 --- a/src/transformers/models/olmo2/modular_olmo2.py +++ b/src/transformers/models/olmo2/modular_olmo2.py @@ -1,30 +1,23 @@ -import math -from typing import Optional, Tuple +from typing import Callable, Optional, Tuple import torch from torch import nn from ...cache_utils import Cache +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...pytorch_utils import ALL_LAYERNORM_LAYERS -from ...utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, logging -from ..llama.modeling_llama import LlamaRMSNorm +from ...utils import logging +from ..llama.modeling_llama import LlamaRMSNorm, eager_attention_forward from ..olmo.configuration_olmo import OlmoConfig from ..olmo.modeling_olmo import ( OlmoAttention, OlmoDecoderLayer, - OlmoFlashAttention2, OlmoForCausalLM, OlmoModel, - OlmoPreTrainedModel, - OlmoSdpaAttention, apply_rotary_pos_emb, - repeat_kv, ) -if is_flash_attn_2_available(): - from ...modeling_flash_attention_utils import _flash_attention_forward - logger = logging.get_logger(__name__) @@ -170,112 +163,30 @@ class Olmo2RMSNorm(LlamaRMSNorm): class Olmo2Attention(OlmoAttention): def __init__(self, config: Olmo2Config, layer_idx: Optional[int] = None): super().__init__(config, layer_idx=layer_idx) - self.q_norm = Olmo2RMSNorm(self.num_heads * self.head_dim, config.rms_norm_eps) - self.k_norm = Olmo2RMSNorm(self.num_key_value_heads * self.head_dim, config.rms_norm_eps) + self.q_norm = Olmo2RMSNorm(config.num_attention_heads * self.head_dim, config.rms_norm_eps) + self.k_norm = Olmo2RMSNorm(config.num_key_value_heads * self.head_dim, config.rms_norm_eps) def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_norm(self.q_proj(hidden_states)) - key_states = self.k_norm(self.k_proj(hidden_states)) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - cos, sin = self.rotary_emb(value_states, position_ids) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # 
sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class Olmo2FlashAttention2(OlmoFlashAttention2, Olmo2Attention): - """ - OLMo2 flash attention module. This module inherits from `Olmo2Attention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - def __init__(self, *args, **kwargs): - Olmo2Attention.__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
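In the modular file, `Olmo2Attention` keeps `OlmoAttention`'s weights and only adds the `q_norm`/`k_norm` layers, which normalize the full projected query/key vectors before they are reshaped into heads. Below is a toy sketch of that ordering with a generic RMSNorm and made-up sizes; everything in it is illustrative rather than the library code.

```python
import torch
import torch.nn as nn


class SketchRMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        input_dtype = x.dtype
        x = x.to(torch.float32)
        variance = x.pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.eps)
        return self.weight * x.to(input_dtype)


hidden_size, num_heads, head_dim, seq_len = 64, 4, 16, 5
q_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=False)
# the norm spans the whole projection (num_heads * head_dim), as in Olmo2Attention above
q_norm = SketchRMSNorm(num_heads * head_dim)

hidden_states = torch.randn(2, seq_len, hidden_size)
query = q_norm(q_proj(hidden_states))                                # (2, seq_len, num_heads * head_dim)
query = query.view(2, seq_len, num_heads, head_dim).transpose(1, 2)  # (2, num_heads, seq_len, head_dim)
print(query.shape)
```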
- self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - output_attentions = False - - bsz, q_len, _ = hidden_states.size() + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) query_states = self.q_norm(self.q_proj(hidden_states)) key_states = self.k_norm(self.k_proj(hidden_states)) value_states = self.v_proj(hidden_states) - # Flash attention requires the input to have the shape - # batch_size x seq_length x head_dim x hidden_dim - # therefore we just need to keep the original shape - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + query_states = query_states.view(hidden_shape).transpose(1, 2) + key_states = key_states.view(hidden_shape).transpose(1, 2) + value_states = value_states.view(hidden_shape).transpose(1, 2) - cos, sin = self.rotary_emb(value_states, position_ids) + cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: @@ -283,129 +194,30 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache - # to be able to avoid many of these transpose/reshape/view. - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - dropout_rate = self.attention_dropout if self.training else 0.0 - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. (OlmoRMSNorm handles it correctly) - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. 
This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - attn_output = _flash_attention_forward( + attn_output, attn_weights = attention_interface( + self, query_states, key_states, value_states, attention_mask, - q_len, - position_ids=position_ids, - dropout=dropout_rate, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - is_causal=self.is_causal, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, ) - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class Olmo2SdpaAttention(OlmoSdpaAttention, Olmo2Attention): - # Adapted from Olmo2Attention.forward - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "Olmo2Model is using Olmo2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
- ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - ) - bsz, q_len, _ = hidden_states.size() - query_states = self.q_norm(self.q_proj(hidden_states)) - key_states = self.k_norm(self.k_proj(hidden_states)) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - cos, sin = self.rotary_emb(value_states, position_ids) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - causal_mask = attention_mask - # if attention_mask is not None and cache_position is not None: - if attention_mask is not None: - causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and causal_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. 
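The comment at the end of the removed SDPA path describes resolving `is_causal` before the call instead of inline, so the `scaled_dot_product_attention` call stays compatible with `torch.compile`'s dynamic shapes. A toy illustration of that pattern (tensor sizes are made up):

```python
import torch

query = torch.randn(1, 2, 5, 8)  # (batch, heads, q_len, head_dim)
key = torch.randn(1, 2, 5, 8)
value = torch.randn(1, 2, 5, 8)
causal_mask = None
q_len = query.shape[-2]

# resolved ahead of the call, not as an inline conditional inside it
is_causal = causal_mask is None and q_len > 1

attn_output = torch.nn.functional.scaled_dot_product_attention(
    query, key, value, attn_mask=causal_mask, dropout_p=0.0, is_causal=is_causal
)
print(attn_output.shape)  # torch.Size([1, 2, 5, 8])
```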
- is_causal = True if causal_mask is None and q_len > 1 else False - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=causal_mask, - dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=is_causal, - ) - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(bsz, q_len, self.hidden_size) - attn_output = self.o_proj(attn_output) - return attn_output, None, past_key_value + return attn_output, attn_weights # The OLMo2 layers are identical to those of the OLMo model except: @@ -416,6 +228,7 @@ def __init__(self, config: Olmo2Config, layer_idx: int): super().__init__(config, layer_idx=layer_idx) self.post_attention_layernorm = Olmo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_feedforward_layernorm = Olmo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.self_attn = Olmo2Attention(config=config, layer_idx=layer_idx) del self.input_layernorm def forward( @@ -427,12 +240,13 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, @@ -440,6 +254,7 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + position_embeddings=position_embeddings, **kwargs, ) hidden_states = self.post_attention_layernorm(hidden_states) @@ -454,13 +269,8 @@ def forward( outputs = (hidden_states,) if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) - return outputs - -class Olmo2PreTrainedModel(OlmoPreTrainedModel): - pass + return outputs # The OLMo2 model is identical to the OLMo model, except RMSNorm is used instead of @@ -468,22 +278,20 @@ class Olmo2PreTrainedModel(OlmoPreTrainedModel): class Olmo2Model(OlmoModel): def __init__(self, config: Olmo2Config): super().__init__(config) + self.norm = Olmo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.layers = nn.ModuleList( [Olmo2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) - self.norm = Olmo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) # The heads now only need to redefine the model inside to the correct `RobertaModel` class Olmo2ForCausalLM(OlmoForCausalLM): - def __init__(self, config: Olmo2Config): - super().__init__(config) - self.model = Olmo2Model(config) + pass __all__ = [ "Olmo2Config", "Olmo2ForCausalLM", "Olmo2Model", - "Olmo2PreTrainedModel", + "Olmo2PreTrainedModel", # noqa: F822 ] diff --git a/src/transformers/models/olmoe/modeling_olmoe.py b/src/transformers/models/olmoe/modeling_olmoe.py index 4398e2f5c9a1fd..fa3c2f3cd4d11b 100644 --- a/src/transformers/models/olmoe/modeling_olmoe.py +++ b/src/transformers/models/olmoe/modeling_olmoe.py @@ -160,40 +160,18 @@ def extra_repr(self): class OlmoeRotaryEmbedding(nn.Module): def __init__( self, - dim=None, - max_position_embeddings=2048, - base=10000, + config: OlmoeConfig, device=None, - scaling_factor=1.0, - rope_type="default", - config: Optional[OlmoeConfig] = None, 
): super().__init__() - # TODO (joao): remove the `if` below, only used for BC self.rope_kwargs = {} - if config is None: - logger.warning_once( - "`OlmoeRotaryEmbedding` can now be fully parameterized by passing the model config through the " - "`config` argument. All other arguments will be removed in v4.46" - ) - self.rope_kwargs = { - "rope_type": rope_type, - "factor": scaling_factor, - "dim": dim, - "base": base, - "max_position_embeddings": max_position_embeddings, - } - self.rope_type = rope_type - self.max_seq_len_cached = max_position_embeddings - self.original_max_seq_len = max_position_embeddings + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: - # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] @@ -293,7 +271,8 @@ def __init__(self, config): self.act_fn = ACT2FN[config.hidden_act] def forward(self, x): - return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj # Copied from transformers.models.llama.modeling_llama.repeat_kv @@ -422,7 +401,6 @@ class OlmoeFlashAttention2(OlmoeAttention): flash attention and deal with padding tokens in case the input contains any of them. """ - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py index e4ef510f099d66..3350ae1a23c2b7 100644 --- a/src/transformers/models/opt/modeling_opt.py +++ b/src/transformers/models/opt/modeling_opt.py @@ -257,7 +257,6 @@ class OptFlashAttention2(OPTAttention): attention and deal with padding tokens in case the input contains any of them. """ - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index 884ee4d86aafcc..8d3c20b9ace717 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -59,40 +59,18 @@ class PersimmonRotaryEmbedding(nn.Module): def __init__( self, - dim=None, - max_position_embeddings=2048, - base=10000, + config: PersimmonConfig, device=None, - scaling_factor=1.0, - rope_type="default", - config: Optional[PersimmonConfig] = None, ): super().__init__() - # TODO (joao): remove the `if` below, only used for BC self.rope_kwargs = {} - if config is None: - logger.warning_once( - "`PersimmonRotaryEmbedding` can now be fully parameterized by passing the model config through the " - "`config` argument. 
All other arguments will be removed in v4.46" - ) - self.rope_kwargs = { - "rope_type": rope_type, - "factor": scaling_factor, - "dim": dim, - "base": base, - "max_position_embeddings": max_position_embeddings, - } - self.rope_type = rope_type - self.max_seq_len_cached = max_position_embeddings - self.original_max_seq_len = max_position_embeddings + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: - # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 8e60798e857f03..477896decd5318 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -1,33 +1,19 @@ -# coding=utf-8 -# Copyright 2023 Microsoft and the HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""PyTorch Phi model.""" - -import math -from typing import List, Optional, Tuple, Union +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/phi/modular_phi.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_phi.py file directly. One of our CI enforces this. 
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +from typing import Callable, List, Optional, Tuple, Union import torch -import torch.utils.checkpoint -from packaging import version -from torch import nn -from torch.nn import CrossEntropyLoss +import torch.nn as nn from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -35,119 +21,25 @@ TokenClassifierOutput, ) from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack from ...utils import ( + LossKwargs, add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, - get_torch_version, - is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, ) from .configuration_phi import PhiConfig -if is_flash_attn_2_available(): - from ...modeling_flash_attention_utils import _flash_attention_forward - - logger = logging.get_logger(__name__) -_CHECKPOINT_FOR_DOC = "microsoft/phi-1" +_CHECKPOINT_FOR_DOC = "meta-phi/Phi-2-7b-hf" _CONFIG_FOR_DOC = "PhiConfig" -# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Phi -class PhiRotaryEmbedding(nn.Module): - def __init__( - self, - dim=None, - max_position_embeddings=2048, - base=10000, - device=None, - scaling_factor=1.0, - rope_type="default", - config: Optional[PhiConfig] = None, - ): - super().__init__() - # TODO (joao): remove the `if` below, only used for BC - self.rope_kwargs = {} - if config is None: - logger.warning_once( - "`PhiRotaryEmbedding` can now be fully parameterized by passing the model config through the " - "`config` argument. 
All other arguments will be removed in v4.46" - ) - self.rope_kwargs = { - "rope_type": rope_type, - "factor": scaling_factor, - "dim": dim, - "base": base, - "max_position_embeddings": max_position_embeddings, - } - self.rope_type = rope_type - self.max_seq_len_cached = max_position_embeddings - self.original_max_seq_len = max_position_embeddings - else: - # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings - - self.config = config - self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self.original_inv_freq = self.inv_freq - - def _dynamic_frequency_update(self, position_ids, device): - """ - dynamic RoPE layers should recompute `inv_freq` in the following situations: - 1 - growing beyond the cached sequence length (allow scaling) - 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) - """ - seq_len = torch.max(position_ids) + 1 - if seq_len > self.max_seq_len_cached: # growth - inv_freq, self.attention_scaling = self.rope_init_fn( - self.config, device, seq_len=seq_len, **self.rope_kwargs - ) - self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation - self.max_seq_len_cached = seq_len - - if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset - self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) - self.max_seq_len_cached = self.original_max_seq_len - - @torch.no_grad() - def forward(self, x, position_ids): - if "dynamic" in self.rope_type: - self._dynamic_frequency_update(position_ids, device=x.device) - - # Core RoPE block - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) - position_ids_expanded = position_ids[:, None, :].float() - # Force float32 (see https://github.com/huggingface/transformers/pull/29285) - device_type = x.device.type - device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() - sin = emb.sin() - - # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention - cos = cos * self.attention_scaling - sin = sin * self.attention_scaling - - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) - - -# Copied from transformers.models.llama.modeling_llama.rotate_half def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] @@ -155,7 +47,6 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. 
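Alongside the removed `PhiRotaryEmbedding` boilerplate, the hunks above keep `rotate_half` and `apply_rotary_pos_emb`, where RoPE reduces to a cos/sin blend of each vector with its half-rotated copy. The toy walkthrough below reproduces that rotation on made-up tensors (default RoPE, no scaling factor) to make the cos/sin construction and broadcasting explicit:

```python
import torch


def rotate_half(x):
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


batch, num_heads, seq_len, head_dim = 1, 2, 6, 8
base = 10000.0

# same inverse-frequency construction the rotary embeddings use
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))  # (head_dim / 2,)
position_ids = torch.arange(seq_len).float()
freqs = torch.outer(position_ids, inv_freq)  # (seq_len, head_dim / 2)
emb = torch.cat((freqs, freqs), dim=-1)      # (seq_len, head_dim)
cos, sin = emb.cos(), emb.sin()

q = torch.randn(batch, num_heads, seq_len, head_dim)
k = torch.randn(batch, num_heads, seq_len, head_dim)

# unsqueeze so cos/sin broadcast over the head dimension, as unsqueeze_dim=1 does above
cos, sin = cos[None, None, :, :], sin[None, None, :, :]
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
print(q_embed.shape, k_embed.shape)  # both (1, 2, 6, 8)
```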
@@ -183,23 +74,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): return q_embed, k_embed -# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Phi -class PhiMLP(nn.Module): - def __init__(self, config): - super().__init__() - self.config = config - self.activation_fn = ACT2FN[config.hidden_act] - self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) - self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - hidden_states = self.fc1(hidden_states) - hidden_states = self.activation_fn(hidden_states) - hidden_states = self.fc2(hidden_states) - return hidden_states - - -# Copied from transformers.models.llama.modeling_llama.repeat_kv with llama->phi def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, @@ -212,190 +86,79 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + class PhiAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__(self, config: PhiConfig, layer_idx: Optional[int] = None): + def __init__(self, config: PhiConfig, layer_idx: int): super().__init__() self.config = config self.layer_idx = layer_idx - if layer_idx is None: - logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " - "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class." - ) - + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 self.attention_dropout = config.attention_dropout - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.rope_theta = config.rope_theta - self.rotary_ndims = int(self.head_dim * config.partial_rotary_factor) self.is_causal = True - - if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." 
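`repeat_kv` is what lets the grouped-query models above reuse the plain matmul attention path: key/value heads are expanded by `num_key_value_groups` so they line up with the query heads. A quick shape check of that expansion (sizes are made up):

```python
import torch


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # (batch, num_key_value_heads, seq, head_dim) -> (batch, num_key_value_heads * n_rep, seq, head_dim)
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


num_attention_heads, num_key_value_heads, head_dim, seq_len = 8, 2, 16, 4
n_rep = num_attention_heads // num_key_value_heads  # i.e. num_key_value_groups
key_states = torch.randn(1, num_key_value_heads, seq_len, head_dim)
print(repeat_kv(key_states, n_rep).shape)  # torch.Size([1, 8, 4, 16])
```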
- ) - - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) - self.dense = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=True) - + self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=True) + self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True) + self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True) + self.dense = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=True) + self.rotary_ndims = int(self.head_dim * config.partial_rotary_factor) self.qk_layernorm = config.qk_layernorm if self.qk_layernorm: self.q_layernorm = nn.LayerNorm( - config.hidden_size // self.num_heads, eps=config.layer_norm_eps, elementwise_affine=True + config.hidden_size // config.num_attention_heads, eps=config.layer_norm_eps, elementwise_affine=True ) self.k_layernorm = nn.LayerNorm( - config.hidden_size // self.num_heads, eps=config.layer_norm_eps, elementwise_affine=True - ) - - self.rotary_emb = PhiRotaryEmbedding(config=self.config) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - if self.qk_layernorm: - query_states = self.q_layernorm(query_states) - key_states = self.k_layernorm(key_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - cos, sin = position_embeddings - - # Partial rotary embedding - query_rot, query_pass = ( - query_states[..., : self.rotary_ndims], - query_states[..., self.rotary_ndims :], - ) - key_rot, key_pass = ( - key_states[..., : self.rotary_ndims], - key_states[..., self.rotary_ndims :], - ) - # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor] - query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin) - - # [batch_size, seq_length, num_heads, head_dim] - query_states = torch.cat((query_rot, query_pass), dim=-1) - key_states = torch.cat((key_rot, key_pass), dim=-1) - - if past_key_value is not None: - cache_kwargs = { - "sin": sin, - "cos": cos, - "partial_rotation_size": self.rotary_ndims, - "cache_position": cache_position, - } - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - # Queries and keys upcast to fp32 is required by Phi-2 to avoid overflow - attn_weights = torch.matmul( - 
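The new `PhiAttention.__init__` above no longer carries a per-class eager path; the default route is the standalone `eager_attention_forward` introduced in this diff, which only reads `num_key_value_groups` and `training` from the calling module. A runnable sketch with toy shapes follows; the `Toy` holder is a stand-in, not part of the patch:

```python
# Sketch of the grouped-query eager attention path shown above. `Toy` only
# carries the two attributes the function reads; shapes are illustrative.
from typing import Optional

import torch
from torch import nn


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # (batch, num_kv_heads, seq, head_dim) -> (batch, num_kv_heads * n_rep, seq, head_dim)
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def eager_attention_forward(module, query, key, value, attention_mask: Optional[torch.Tensor], scaling, dropout=0.0):
    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)
    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask[:, :, :, : key_states.shape[-2]]
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value_states).transpose(1, 2).contiguous()
    return attn_output, attn_weights


class Toy(nn.Module):
    num_key_value_groups = 4  # e.g. 8 query heads sharing 2 key/value heads


q = torch.randn(1, 8, 5, 16)   # (batch, num_heads, seq, head_dim)
k = torch.randn(1, 2, 5, 16)   # only 2 key/value heads
v = torch.randn(1, 2, 5, 16)
out, weights = eager_attention_forward(Toy(), q, k, v, None, scaling=16 ** -0.5)
print(out.shape)               # torch.Size([1, 5, 8, 16])
```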
query_states.to(torch.float32), key_states.to(torch.float32).transpose(2, 3) - ) / math.sqrt(self.head_dim) - - if attention_mask is not None: - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights += causal_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" + config.hidden_size // config.num_attention_heads, eps=config.layer_norm_eps, elementwise_affine=True ) - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - attn_output = self.dense(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class PhiFlashAttention2(PhiAttention): - """ - Phi flash attention module. This module inherits from `PhiAttention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
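Both the removed forward above and its replacement keep Phi's partial rotary scheme: only the first `rotary_ndims = head_dim * partial_rotary_factor` channels of each head are rotated, the rest pass through untouched. A minimal sketch with illustrative shapes and factor:

```python
# Sketch of Phi's partial rotary embedding: rotate the first `rotary_ndims`
# channels of each head, leave the remainder as-is. Values are illustrative;
# the real factor comes from the model config.
import torch


def rotate_half(x):
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


head_dim = 64
partial_rotary_factor = 0.5
rotary_ndims = int(head_dim * partial_rotary_factor)

query_states = torch.randn(1, 8, 10, head_dim)   # (batch, heads, seq, head_dim)
cos = torch.randn(1, 10, rotary_ndims)
sin = torch.randn(1, 10, rotary_ndims)

# Split into a rotary slice and a pass-through slice, rotate only the first one.
query_rot, query_pass = query_states[..., :rotary_ndims], query_states[..., rotary_ndims:]
query_rot = (query_rot * cos.unsqueeze(1)) + (rotate_half(query_rot) * sin.unsqueeze(1))
query_states = torch.cat((query_rot, query_pass), dim=-1)
print(query_states.shape)                        # torch.Size([1, 8, 10, 64])
```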
- self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() - def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - # PhiFlashAttention2 attention does not support output_attentions - - output_attentions = False + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) if self.qk_layernorm: query_states = self.q_layernorm(query_states) key_states = self.k_layernorm(key_states) - # Flash attention requires the input to have the shape - # batch_size x seq_length x head_dim x hidden_dim - # therefore we just need to keep the original shape - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - cos, sin = position_embeddings - # Partial rotary embedding query_rot, query_pass = ( query_states[..., : self.rotary_ndims], @@ -413,199 +176,55 @@ def forward( key_states = torch.cat((key_rot, key_pass), dim=-1) if past_key_value is not None: - cache_kwargs = { - "sin": sin, - "cos": cos, - "partial_rotation_size": self.rotary_ndims, - "cache_position": cache_position, - } + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache - # to be able to avoid many of these transpose/reshape/view. - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - attn_dropout = self.attention_dropout if self.training else 0.0 - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. 
- - if query_states.dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - attn_output = _flash_attention_forward( + attn_output, attn_weights = attention_interface( + self, query_states, key_states, value_states, attention_mask, - q_len, - position_ids=position_ids, - dropout=attn_dropout, - softmax_scale=None, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - is_causal=self.is_causal, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, ) - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.dense(attn_output) + return attn_output, attn_weights - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class PhiSdpaAttention(PhiAttention): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0") - - """ - SDPA attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from - `PhiAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to - SDPA API. - """ - - # Adapted from PhiAttention.forward - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "PhiModel is using PhiSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not " - "support `output_attentions=True`. Falling back to the manual attention implementation, but specifying " - "the manual implementation will be required from Transformers version v5.0.0 onwards. 
This warning can " - 'be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - ) - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - if self.qk_layernorm: - query_states = self.q_layernorm(query_states) - key_states = self.k_layernorm(key_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - cos, sin = position_embeddings - - # Partial rotary embedding - query_rot, query_pass = ( - query_states[..., : self.rotary_ndims], - query_states[..., self.rotary_ndims :], - ) - key_rot, key_pass = ( - key_states[..., : self.rotary_ndims], - key_states[..., self.rotary_ndims :], - ) - # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor] - query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin) - - # [batch_size, seq_length, num_heads, head_dim] - query_states = torch.cat((query_rot, query_pass), dim=-1) - key_states = torch.cat((key_rot, key_pass), dim=-1) - - if past_key_value is not None: - cache_kwargs = { - "sin": sin, - "cos": cos, - "partial_rotation_size": self.rotary_ndims, - "cache_position": cache_position, - } - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - causal_mask = attention_mask - if attention_mask is not None: - causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] - - # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom - # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0. - # Reference: https://github.com/pytorch/pytorch/issues/112577 - if self.require_contiguous_qkv and query_states.device.type == "cuda" and attention_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. 
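The dedicated SDPA and flash-attention subclasses deleted around here are replaced by the `attention_interface` lookup shown earlier in this file. A simplified sketch of that dispatch pattern; the registry below is a stand-in for `ALL_ATTENTION_FUNCTIONS`, and the SDPA branch assumes PyTorch 2.1+ for the `scale` argument:

```python
# Simplified sketch of dispatching on an implementation name instead of
# subclassing the attention module per backend. The registry is a stand-in.
from typing import Callable, Dict, Optional

import torch
from torch import nn


def eager_attention(module, q, k, v, mask: Optional[torch.Tensor], scaling: float, **kwargs):
    attn = torch.matmul(q, k.transpose(2, 3)) * scaling
    if mask is not None:
        attn = attn + mask
    attn = nn.functional.softmax(attn, dim=-1)
    return torch.matmul(attn, v).transpose(1, 2), attn


def sdpa_attention(module, q, k, v, mask: Optional[torch.Tensor], scaling: float, **kwargs):
    out = nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, scale=scaling)
    return out.transpose(1, 2), None   # SDPA does not return attention weights


ATTENTION_FUNCTIONS: Dict[str, Callable] = {"eager": eager_attention, "sdpa": sdpa_attention}


def run_attention(impl: str, q, k, v, scaling: float):
    # Fall back to eager when the requested backend is unknown.
    attention_interface = ATTENTION_FUNCTIONS.get(impl, eager_attention)
    return attention_interface(None, q, k, v, None, scaling)


q = k = v = torch.randn(1, 4, 6, 16)
for impl in ("eager", "sdpa"):
    out, _ = run_attention(impl, q, k, v, scaling=16 ** -0.5)
    print(impl, out.shape)             # both: torch.Size([1, 6, 4, 16])
```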
- is_causal = True if causal_mask is None and q_len > 1 else False - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=causal_mask, - dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=is_causal, - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - attn_output = self.dense(attn_output) - - return attn_output, None, past_key_value +class PhiMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) -PHI_ATTENTION_CLASSES = { - "eager": PhiAttention, - "flash_attention_2": PhiFlashAttention2, - "sdpa": PhiSdpaAttention, -} + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states class PhiDecoderLayer(nn.Module): def __init__(self, config: PhiConfig, layer_idx: int): super().__init__() - self.self_attn = PHI_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx) + self.self_attn = PhiAttention(config, layer_idx=layer_idx) self.mlp = PhiMLP(config) self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.resid_dropout = nn.Dropout(config.resid_pdrop) @@ -615,45 +234,19 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, - past_key_value: Optional[Tuple[torch.Tensor]] = None, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): - input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - position_ids (`torch.LongTensor` of shape `({0})`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range - `[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). 
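The new `PhiDecoderLayer` keeps Phi's parallel residual layout: attention and MLP both read the same layer-normed input, and their outputs are summed with the residual once. A stripped-down sketch with stand-in attention and MLP modules (not the patch's classes):

```python
# Stripped-down sketch of Phi's parallel residual block. The attention and MLP
# stand-ins below are placeholders; only the residual wiring mirrors the layer above.
import torch
from torch import nn


class ToyParallelBlock(nn.Module):
    def __init__(self, hidden_size=32, num_heads=4, intermediate_size=64, resid_pdrop=0.0):
        super().__init__()
        self.input_layernorm = nn.LayerNorm(hidden_size)
        self.self_attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, intermediate_size), nn.GELU(), nn.Linear(intermediate_size, hidden_size)
        )
        self.resid_dropout = nn.Dropout(resid_pdrop)

    def forward(self, hidden_states):
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        attn_out, _ = self.self_attn(hidden_states, hidden_states, hidden_states)
        attn_out = self.resid_dropout(attn_out)
        mlp_out = self.resid_dropout(self.mlp(hidden_states))
        # Parallel residual: attn + mlp + residual, as in PhiDecoderLayer.forward.
        return attn_out + mlp_out + residual


x = torch.randn(2, 5, 32)
print(ToyParallelBlock()(x).shape)   # torch.Size([2, 5, 32])
```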
- past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence - position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): - Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, - with `head_dim` being the embedding dimension of each attention head. - kwargs (`dict`, *optional*): - Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code - into the model - """ - residual = hidden_states hidden_states = self.input_layernorm(hidden_states) # Self Attention - attn_outputs, self_attn_weights, present_key_value = self.self_attn( + attn_outputs, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, @@ -662,6 +255,7 @@ def forward( use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, + **kwargs, ) attn_outputs = self.resid_dropout(attn_outputs) @@ -672,12 +266,74 @@ def forward( if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) - return outputs +class PhiRotaryEmbedding(nn.Module): + def __init__( + self, + config: PhiConfig, + device=None, + ): + super().__init__() + self.rope_kwargs = {} + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + 
with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + PHI_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads @@ -704,12 +360,12 @@ class PhiPreTrainedModel(PreTrainedModel): base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["PhiDecoderLayer"] - _skip_keys_device_placement = "past_key_values" + _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True _supports_sdpa = True _supports_cache_class = True - _supports_static_cache = True _supports_quantized_cache = True + _supports_static_cache = True def _init_weights(self, module): std = self.config.initializer_range @@ -816,17 +472,14 @@ def __init__(self, config: PhiConfig): self.vocab_size = config.vocab_size self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) - self.embed_dropout = nn.Dropout(config.embd_pdrop) self.layers = nn.ModuleList( [PhiDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) - self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.rotary_emb = PhiRotaryEmbedding(config=config) - - self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" - self._use_sdpa = config._attn_implementation == "sdpa" - self.gradient_checkpointing = False + self.embed_dropout = nn.Dropout(config.embd_pdrop) + self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + # Initialize weights and apply final processing self.post_init() @@ -842,54 +495,43 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
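The re-added `PhiRotaryEmbedding` above caches `inv_freq` and emits a float32 cos/sin pair that is cast back to the input dtype. Its default schedule comes from `ROPE_INIT_FUNCTIONS["default"]`; a standalone sketch of that standard recipe, with illustrative `dim` and `base`:

```python
# Standalone sketch of the default RoPE frequency table and the cos/sin pair a
# rotary module like the one above returns. `dim` and `base` are illustrative.
import torch


def default_rope(dim=64, base=10000.0, seq_len=16, device="cpu"):
    # inv_freq[i] = 1 / base^(2i / dim): one frequency per pair of channels.
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
    position_ids = torch.arange(seq_len, dtype=torch.float32, device=device)[None, :]   # (1, seq_len)
    # Outer product of positions and frequencies, duplicated to cover all `dim` channels.
    freqs = position_ids[:, :, None] * inv_freq[None, None, :]                          # (1, seq_len, dim // 2)
    emb = torch.cat((freqs, freqs), dim=-1)                                             # (1, seq_len, dim)
    return emb.cos(), emb.sin()


cos, sin = default_rope()
print(cos.shape, sin.shape)   # torch.Size([1, 16, 64]) twice
```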
- ) - use_cache = False - - # kept for BC (non `Cache` `past_key_values` inputs) - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True - if past_key_values is None: - past_key_values = DynamicCache() - else: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " - "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class " - "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" - ) + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) + if position_ids is None: position_ids = cache_position.unsqueeze(0) @@ -897,7 +539,7 @@ def forward( attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions ) - inputs_embeds = self.embed_dropout(inputs_embeds) + inputs_embeds = self.embed_dropout(inputs_embeds) # diff with Llama hidden_states = inputs_embeds # create position embeddings to be shared across the decoder layers @@ -906,9 +548,8 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_decoder_cache = None - for decoder_layer in self.layers: + for decoder_layer in self.layers[: self.config.num_hidden_layers]: if output_hidden_states: all_hidden_states += (hidden_states,) @@ -918,9 +559,9 @@ def forward( hidden_states, causal_mask, position_ids, + past_key_values, output_attentions, use_cache, - past_key_values, cache_position, position_embeddings, ) @@ -934,36 +575,28 @@ def forward( use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, + **flash_attn_kwargs, ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) - hidden_states = self.final_layernorm(hidden_states) + hidden_states = self.final_layernorm(hidden_states) # diff with Llama # add hidden states from the last decoder layer if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None - if return_legacy_cache: - next_cache = next_cache.to_legacy_cache() - - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( + output = BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, ) + return output if return_dict else output.to_tuple() - # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask def _update_causal_mask( self, attention_mask: torch.Tensor, @@ -1030,7 +663,6 @@ def _update_causal_mask( return causal_mask @staticmethod - # Copied from 
transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, @@ -1087,40 +719,37 @@ def _prepare_4d_causal_attention_mask_with_cache_position( return causal_mask +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... + + class PhiForCausalLM(PhiPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} - # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with Llama->Phi,bias=False->bias=True def __init__(self, config): super().__init__(config) self.model = PhiModel(config) self.vocab_size = config.vocab_size - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=True) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) # Initialize weights and apply final processing self.post_init() - # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_input_embeddings def get_input_embeddings(self): return self.model.embed_tokens - # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_input_embeddings def set_input_embeddings(self, value): self.model.embed_tokens = value - # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_output_embeddings def get_output_embeddings(self): return self.lm_head - # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_output_embeddings def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings - # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_decoder def set_decoder(self, decoder): self.model = decoder - # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_decoder def get_decoder(self): return self.model @@ -1131,7 +760,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, @@ -1140,7 +769,7 @@ def forward( return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, num_logits_to_keep: int = 0, - **loss_kwargs, + **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -1161,18 +790,17 @@ def forward( ```python >>> from transformers import AutoTokenizer, PhiForCausalLM - >>> model = PhiForCausalLM.from_pretrained("microsoft/phi-1") - >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-1") + >>> model = PhiForCausalLM.from_pretrained("meta-phi/Phi-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-phi/Phi-2-7b-hf") - >>> prompt = "This is an example script ." + >>> prompt = "Hey, are you conscious? Can you talk to me?" >>> inputs = tokenizer(prompt, return_tensors="pt") >>> # Generate >>> generate_ids = model.generate(inputs.input_ids, max_length=30) >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - 'This is an example script .\n\n\n\nfrom typing import List\n\ndef find_most_common_letter(words: List[str' + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -1191,6 +819,7 @@ def forward( output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, + **kwargs, ) hidden_states = outputs[0] @@ -1199,7 +828,7 @@ def forward( loss = None if labels is not None: - loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs) + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) if not return_dict: output = (logits,) + outputs[1:] @@ -1216,7 +845,7 @@ def forward( @add_start_docstrings( """ - The PhiModel with a sequence classification head on top (linear layer). + The Phi Model transformer with a sequence classification head on top (linear layer). [`PhiForSequenceClassification`] uses the last token in order to do the classification, as other causal models (e.g. GPT-2) do. @@ -1229,7 +858,6 @@ def forward( """, PHI_START_DOCSTRING, ) -# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with LLAMA->PHI,Llama->Phi with self.transformer->self.model, transformer_outputs->model_outputs class PhiForSequenceClassification(PhiPreTrainedModel): def __init__(self, config): super().__init__(config) @@ -1268,7 +896,7 @@ def forward( """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict - model_outputs = self.model( + transformer_outputs = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -1279,7 +907,7 @@ def forward( output_hidden_states=output_hidden_states, return_dict=return_dict, ) - hidden_states = model_outputs[0] + hidden_states = transformer_outputs[0] logits = self.score(hidden_states) if input_ids is not None: @@ -1307,44 +935,48 @@ def forward( loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) if not return_dict: - output = (pooled_logits,) + model_outputs[1:] + output = (pooled_logits,) + transformer_outputs[1:] return ((loss,) + output) if loss is not None else output return SequenceClassifierOutputWithPast( loss=loss, logits=pooled_logits, - past_key_values=model_outputs.past_key_values, - hidden_states=model_outputs.hidden_states, - attentions=model_outputs.attentions, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, ) @add_start_docstrings( """ - PhiModel with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for - Named-Entity-Recognition (NER) tasks. + The Phi Model transformer with a token classification head on top (a linear layer on top of the hidden-states + output) e.g. for Named-Entity-Recognition (NER) tasks. 
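As the sequence-classification docstring above notes, `PhiForSequenceClassification` pools the last non-padding token of each row, GPT-2 style. A toy sketch of that pooling step; the batched indexing in the actual modeling code may differ slightly, and the pad id and tensors here are made up:

```python
# Toy sketch of last-non-padding-token pooling for sequence classification
# heads on causal LMs. Assumes right padding.
import torch

pad_token_id = 0
input_ids = torch.tensor([[5, 6, 7, 0, 0],
                          [8, 9, 0, 0, 0]])                  # right-padded batch
logits = torch.randn(2, 5, 3)                                # (batch, seq_len, num_labels)

# Index of the last token that is not padding in each row.
sequence_lengths = (input_ids != pad_token_id).int().sum(dim=-1) - 1      # tensor([2, 1])
pooled_logits = logits[torch.arange(logits.shape[0]), sequence_lengths]   # (batch, num_labels)
print(pooled_logits.shape)   # torch.Size([2, 3])
```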
""", PHI_START_DOCSTRING, ) -# Copied from transformers.models.mpt.modeling_mpt.MptForTokenClassification with MPT->PHI,Mpt->Phi,self.transformer->self.model,transformer_outputs->model_outputs class PhiForTokenClassification(PhiPreTrainedModel): - def __init__(self, config: PhiConfig): + def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels - self.model = PhiModel(config) - if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None: + if getattr(config, "classifier_dropout", None) is not None: classifier_dropout = config.classifier_dropout - elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None: + elif getattr(config, "hidden_dropout", None) is not None: classifier_dropout = config.hidden_dropout else: classifier_dropout = 0.1 self.dropout = nn.Dropout(classifier_dropout) - self.classifier = nn.Linear(config.hidden_size, config.num_labels) + self.score = nn.Linear(config.hidden_size, config.num_labels) # Initialize weights and apply final processing self.post_init() + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + @add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING) @add_code_sample_docstrings( checkpoint=_CHECKPOINT_FOR_DOC, @@ -1354,16 +986,16 @@ def __init__(self, config: PhiConfig): def forward( self, input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, attention_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - **deprecated_arguments, - ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]: + ) -> Union[Tuple, TokenClassifierOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. 
Indices should be in `[0, ..., @@ -1372,38 +1004,32 @@ def forward( """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict - model_outputs = self.model( + outputs = self.model( input_ids, - past_key_values=past_key_values, attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) - - hidden_states = model_outputs[0] - hidden_states = self.dropout(hidden_states) - logits = self.classifier(hidden_states) + sequence_output = outputs[0] + sequence_output = self.dropout(sequence_output) + logits = self.score(sequence_output) loss = None if labels is not None: - # move labels to correct device to enable model parallelism - labels = labels.to(logits.device) - batch_size, seq_length = labels.shape - loss_fct = CrossEntropyLoss() - loss = loss_fct( - logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length) - ) + loss = self.loss_function(logits, labels, self.config) if not return_dict: - output = (logits,) + model_outputs[2:] + output = (logits,) + outputs[2:] return ((loss,) + output) if loss is not None else output return TokenClassifierOutput( loss=loss, logits=logits, - hidden_states=model_outputs.hidden_states, - attentions=model_outputs.attentions, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, ) diff --git a/src/transformers/models/phi/modular_phi.py b/src/transformers/models/phi/modular_phi.py new file mode 100644 index 00000000000000..0faa4629f1a768 --- /dev/null +++ b/src/transformers/models/phi/modular_phi.py @@ -0,0 +1,295 @@ +from typing import Callable, Optional, Tuple, Union + +import torch +import torch.nn as nn + +from ...cache_utils import Cache, DynamicCache +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_outputs import ( + BaseModelOutputWithPast, +) +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS +from ...processing_utils import Unpack +from ...utils import logging +from ..clip.modeling_clip import CLIPMLP +from ..llama.modeling_llama import ( + LlamaAttention, + LlamaForCausalLM, + LlamaForSequenceClassification, + LlamaForTokenClassification, + LlamaModel, + apply_rotary_pos_emb, + eager_attention_forward, # copied from Llama +) +from .configuration_phi import PhiConfig + + +logger = logging.get_logger(__name__) + + +class PhiAttention(LlamaAttention): + def __init__(self, config: PhiConfig, layer_idx: int): + super().__init__(config, layer_idx) + self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=True) + self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True) + self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True) + self.dense = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=True) + del self.o_proj + self.rotary_ndims = int(self.head_dim * config.partial_rotary_factor) + self.qk_layernorm = config.qk_layernorm + if self.qk_layernorm: + self.q_layernorm = nn.LayerNorm( + config.hidden_size // config.num_attention_heads, eps=config.layer_norm_eps, elementwise_affine=True + ) + self.k_layernorm = nn.LayerNorm( + config.hidden_size // config.num_attention_heads, eps=config.layer_norm_eps, elementwise_affine=True + ) + + def forward( + self, + hidden_states: torch.Tensor, + 
position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + if self.qk_layernorm: + query_states = self.q_layernorm(query_states) + key_states = self.k_layernorm(key_states) + + cos, sin = position_embeddings + # Partial rotary embedding + query_rot, query_pass = ( + query_states[..., : self.rotary_ndims], + query_states[..., self.rotary_ndims :], + ) + key_rot, key_pass = ( + key_states[..., : self.rotary_ndims], + key_states[..., self.rotary_ndims :], + ) + # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor] + query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin) + + # [batch_size, seq_length, num_heads, head_dim] + query_states = torch.cat((query_rot, query_pass), dim=-1) + key_states = torch.cat((key_rot, key_pass), dim=-1) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
+ ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.dense(attn_output) + return attn_output, attn_weights + + +class PhiMLP(CLIPMLP): + pass + + +class PhiDecoderLayer(nn.Module): + def __init__(self, config: PhiConfig, layer_idx: int): + super().__init__() + self.self_attn = PhiAttention(config, layer_idx=layer_idx) + self.mlp = PhiMLP(config) + self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.resid_dropout = nn.Dropout(config.resid_pdrop) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + attn_outputs, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + attn_outputs = self.resid_dropout(attn_outputs) + + feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states)) + hidden_states = attn_outputs + feed_forward_hidden_states + residual + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + return outputs + + +class PhiModel(LlamaModel): + def __init__(self, config: PhiConfig): + super().__init__(config) + self.layers = nn.ModuleList( + [PhiDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.embed_dropout = nn.Dropout(config.embd_pdrop) + self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + del self.norm + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not 
None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + inputs_embeds = self.embed_dropout(inputs_embeds) # diff with Llama + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.final_layernorm(hidden_states) # diff with Llama + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + output = BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + return output if return_dict else output.to_tuple() + + +class PhiForCausalLM(LlamaForCausalLM): + pass + + +class PhiForSequenceClassification(LlamaForSequenceClassification): + pass + + +class PhiForTokenClassification(LlamaForTokenClassification): + pass diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index bae3f6d4cdaeaa..908fd982b9c73c 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -74,7 +74,8 @@ def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" -# Copied from transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding with gemma->phi3, Gemma->Phi3 +# copied from transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding with gemma->phi3, Gemma->Phi3 +# TODO cyril: modular class Phi3RotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() @@ -431,7 +432,6 @@ class 
Phi3FlashAttention2(Phi3Attention): flash attention and deal with padding tokens in case the input contains any of them. """ - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -550,8 +550,8 @@ def forward( return attn_output, attn_weights, past_key_value -# copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Phi3 -# TODO @Arthur no longer copied from LLama after static cache +# NO LONGER EXIST copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Phi3 +# TODO cyril: modular class Phi3SdpaAttention(Phi3Attention): """ Phi3 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from diff --git a/src/transformers/models/phimoe/modeling_phimoe.py b/src/transformers/models/phimoe/modeling_phimoe.py index 82763ccea62e4c..cd54b226e1d85c 100644 --- a/src/transformers/models/phimoe/modeling_phimoe.py +++ b/src/transformers/models/phimoe/modeling_phimoe.py @@ -186,7 +186,6 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -912,10 +911,12 @@ class PhimoePreTrainedModel(PreTrainedModel): base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["PhimoeDecoderLayer"] - _skip_keys_device_placement = "past_key_values" + _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True _supports_sdpa = True _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True def _init_weights(self, module): std = self.config.initializer_range diff --git a/src/transformers/models/pixtral/modeling_pixtral.py b/src/transformers/models/pixtral/modeling_pixtral.py index b65fbd634ba789..03886d4a528478 100644 --- a/src/transformers/models/pixtral/modeling_pixtral.py +++ b/src/transformers/models/pixtral/modeling_pixtral.py @@ -216,6 +216,7 @@ def forward( class PixtralMLP(nn.Module): def __init__(self, config): super().__init__() + self.config = config self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) @@ -223,8 +224,9 @@ def __init__(self, config): self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) self.act_fn = ACT2FN[config.hidden_act] - def forward(self, hidden_state): - return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Pixtral diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index f4875386253c43..36fb1ddf1390ac 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -1,36 +1,19 @@ -# coding=utf-8 -# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. 
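The new `modular_phi.py` above and the auto-generated header that opens the rewritten Qwen2 file below both follow the modular-transformers pattern: a short modular file subclasses a reference model and a generator expands it into the standalone `modeling_*.py`. A schematic sketch of the authoring side; the class names are placeholders, not the real Phi or Llama modules:

```python
# Schematic sketch of the modular pattern: inherit from a reference model's
# module, override or delete what differs, and let the generator inline the
# rest into the standalone modeling file. Names are placeholders.
from torch import nn


class ReferenceAttention(nn.Module):          # stands in for e.g. LlamaAttention
    def __init__(self, hidden_size=32):
        super().__init__()
        self.q_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.k_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.v_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.o_proj = nn.Linear(hidden_size, hidden_size, bias=False)


class MyModelAttention(ReferenceAttention):   # what a modular_*.py file would declare
    def __init__(self, hidden_size=32):
        super().__init__(hidden_size)
        # Re-declare projections with the bias/naming this model needs ...
        self.q_proj = nn.Linear(hidden_size, hidden_size, bias=True)
        self.dense = nn.Linear(hidden_size, hidden_size, bias=True)
        # ... and drop what the model does not have, as modular_phi.py does with o_proj.
        del self.o_proj


attn = MyModelAttention()
print(hasattr(attn, "o_proj"), hasattr(attn, "dense"))   # False True
```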
It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""PyTorch Qwen2 model.""" - -import math -from typing import List, Optional, Tuple, Union +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/qwen2/modular_qwen2.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_qwen2.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +from typing import Callable, List, Optional, Tuple, Union import torch -import torch.utils.checkpoint from torch import nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache +from ...cache_utils import Cache, DynamicCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -39,140 +22,41 @@ TokenClassifierOutput, ) from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack from ...utils import ( + LossKwargs, add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, ) from .configuration_qwen2 import Qwen2Config -if is_flash_attn_2_available(): - from ...modeling_flash_attention_utils import _flash_attention_forward - - logger = logging.get_logger(__name__) - -_CHECKPOINT_FOR_DOC = "Qwen/Qwen2-7B" +_CHECKPOINT_FOR_DOC = "meta-qwen2/Qwen2-2-7b-hf" _CONFIG_FOR_DOC = "Qwen2Config" -# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Qwen2 -class Qwen2RMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - Qwen2RMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - - def extra_repr(self): - return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" - - -# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Qwen2 -class 
Qwen2RotaryEmbedding(nn.Module): - def __init__( - self, - dim=None, - max_position_embeddings=2048, - base=10000, - device=None, - scaling_factor=1.0, - rope_type="default", - config: Optional[Qwen2Config] = None, - ): +class Qwen2MLP(nn.Module): + def __init__(self, config): super().__init__() - # TODO (joao): remove the `if` below, only used for BC - self.rope_kwargs = {} - if config is None: - logger.warning_once( - "`Qwen2RotaryEmbedding` can now be fully parameterized by passing the model config through the " - "`config` argument. All other arguments will be removed in v4.46" - ) - self.rope_kwargs = { - "rope_type": rope_type, - "factor": scaling_factor, - "dim": dim, - "base": base, - "max_position_embeddings": max_position_embeddings, - } - self.rope_type = rope_type - self.max_seq_len_cached = max_position_embeddings - self.original_max_seq_len = max_position_embeddings - else: - # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings - self.config = config - self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self.original_inv_freq = self.inv_freq - - def _dynamic_frequency_update(self, position_ids, device): - """ - dynamic RoPE layers should recompute `inv_freq` in the following situations: - 1 - growing beyond the cached sequence length (allow scaling) - 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) - """ - seq_len = torch.max(position_ids) + 1 - if seq_len > self.max_seq_len_cached: # growth - inv_freq, self.attention_scaling = self.rope_init_fn( - self.config, device, seq_len=seq_len, **self.rope_kwargs - ) - self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation - self.max_seq_len_cached = seq_len - - if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset - self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) - self.max_seq_len_cached = self.original_max_seq_len - - @torch.no_grad() - def forward(self, x, position_ids): - if "dynamic" in self.rope_type: - self._dynamic_frequency_update(position_ids, device=x.device) - - # Core RoPE block - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) - position_ids_expanded = position_ids[:, None, :].float() - # Force float32 (see https://github.com/huggingface/transformers/pull/29285) - device_type = x.device.type - device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() - sin = emb.sin() - - # Advanced RoPE types (e.g. 
yarn) apply a post-processing scaling factor, equivalent to scaling attention - cos = cos * self.attention_scaling - sin = sin * self.attention_scaling + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj -# Copied from transformers.models.llama.modeling_llama.rotate_half def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] @@ -180,7 +64,6 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -208,22 +91,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): return q_embed, k_embed -# Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Qwen2 -class Qwen2MLP(nn.Module): - def __init__(self, config): - super().__init__() - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, hidden_state): - return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) - - -# Copied from transformers.models.llama.modeling_llama.repeat_kv def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, @@ -236,366 +103,160 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + class Qwen2Attention(nn.Module): - """ - Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer - and "Generating Long Sequences with Sparse Transformers". 
- """ + """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None): + def __init__(self, config: Qwen2Config, layer_idx: int): super().__init__() self.config = config self.layer_idx = layer_idx - if layer_idx is None: - logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " - "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class." - ) - - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.max_position_embeddings = config.max_position_embeddings - self.rope_theta = config.rope_theta - self.is_causal = True + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 self.attention_dropout = config.attention_dropout - - if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." - ) - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) - - self.rotary_emb = Qwen2RotaryEmbedding(config=self.config) + self.is_causal = True + self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=True) + self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True) + self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True) + self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) - query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = 
self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class Qwen2FlashAttention2(Qwen2Attention): - """ - Qwen2 flash attention module, following Qwen2 attention module. This module inherits from `Qwen2Attention` - as the weights of the module stays untouched. The only required change would be on the forward pass - where it needs to correctly call the public API of flash attention and deal with padding tokens - in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom - config.max_window_layers layers. - """ - - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
- self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - ): - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - dropout_rate = 0.0 if not self.training else self.attention_dropout - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in float16 just to be sure everything works as expected. - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - # Reashape to the expected shape for Flash Attention - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - + sliding_window = None if ( self.config.use_sliding_window and getattr(self.config, "sliding_window", None) is not None and self.layer_idx >= self.config.max_window_layers ): sliding_window = self.config.sliding_window - else: - sliding_window = None - attn_output = _flash_attention_forward( + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
+ ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, query_states, key_states, value_states, attention_mask, - q_len, - position_ids=position_ids, - dropout=dropout_rate, - sliding_window=sliding_window, - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=sliding_window, # main diff with Llama + **kwargs, ) - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) + return attn_output, attn_weights - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class Qwen2SdpaAttention(Qwen2Attention): - """ - Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from - `Qwen2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to - SDPA API. - """ - - # Adapted from Qwen2Attention.forward - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "Qwen2Model is using Qwen2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
- ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - ) - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - causal_mask = attention_mask - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and attention_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. 
- is_causal = True if causal_mask is None and q_len > 1 else False - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=causal_mask, - dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=is_causal, - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(bsz, q_len, self.hidden_size) - - attn_output = self.o_proj(attn_output) - - return attn_output, None, past_key_value +class Qwen2RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Qwen2RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) -QWEN2_ATTENTION_CLASSES = { - "eager": Qwen2Attention, - "flash_attention_2": Qwen2FlashAttention2, - "sdpa": Qwen2SdpaAttention, -} + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" class Qwen2DecoderLayer(nn.Module): def __init__(self, config: Qwen2Config, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size - + self.self_attn = Qwen2Attention(config=config, layer_idx=layer_idx) + self.mlp = Qwen2MLP(config) + self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) if config.sliding_window and config._attn_implementation != "flash_attention_2": logger.warning_once( f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; " "unexpected results may be encountered." ) - self.self_attn = QWEN2_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) - - self.mlp = Qwen2MLP(config) - self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - **kwargs, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): attention mask of size - `(batch, sequence_length)` where padding elements are indicated by 0. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). 
- past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence. - position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): - Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, - with `head_dim` being the embedding dimension of each attention head. - kwargs (`dict`, *optional*): - Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code - into the model - """ - residual = hidden_states hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, @@ -604,6 +265,7 @@ def forward( use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, + **kwargs, ) hidden_states = residual + hidden_states @@ -614,16 +276,77 @@ def forward( hidden_states = residual + hidden_states outputs = (hidden_states,) - if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) - return outputs +class Qwen2RotaryEmbedding(nn.Module): + def __init__( + self, + config: Qwen2Config, + device=None, + ): + super().__init__() + self.rope_kwargs = {} + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type 
if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + QWEN2_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads @@ -650,7 +373,7 @@ class Qwen2PreTrainedModel(PreTrainedModel): base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["Qwen2DecoderLayer"] - _skip_keys_device_placement = "past_key_values" + _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True _supports_sdpa = True _supports_cache_class = True @@ -690,7 +413,7 @@ def _init_weights(self, module): Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. - If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see `past_key_values`). If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] @@ -765,11 +488,10 @@ def __init__(self, config: Qwen2Config): self.layers = nn.ModuleList( [Qwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) - self._attn_implementation = config._attn_implementation self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = Qwen2RotaryEmbedding(config=config) - self.gradient_checkpointing = False + # Initialize weights and apply final processing self.post_init() @@ -785,54 +507,43 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
- ) - use_cache = False - - # kept for BC (non `Cache` `past_key_values` inputs) - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True - if past_key_values is None: - past_key_values = DynamicCache() - else: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " - "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class " - "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" - ) + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) + if position_ids is None: position_ids = cache_position.unsqueeze(0) @@ -848,9 +559,8 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_decoder_cache = None - for decoder_layer in self.layers: + for decoder_layer in self.layers[: self.config.num_hidden_layers]: if output_hidden_states: all_hidden_states += (hidden_states,) @@ -876,13 +586,11 @@ def forward( use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, + **flash_attn_kwargs, ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -892,20 +600,14 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None - if return_legacy_cache: - next_cache = next_cache.to_legacy_cache() - - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( + output = BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, ) + return output if return_dict else output.to_tuple() - # Copied from transformers.models.phi3.modeling_phi3.Phi3Model._update_causal_mask def _update_causal_mask( self, attention_mask: torch.Tensor, @@ -924,30 +626,21 @@ def _update_causal_mask( # to infer the attention mask. 
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) - using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if ( - self.config._attn_implementation == "sdpa" - and not (using_static_cache or using_sliding_window_cache) - and not output_attentions - ): + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens, - sliding_window=self.config.sliding_window, is_training=self.training, ): return None dtype, device = input_tensor.dtype, input_tensor.device - min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] - # SlidingWindowCache or StaticCache - if using_sliding_window_cache or using_static_cache: + if using_static_cache: target_length = past_key_values.get_max_cache_shape() - # DynamicCache or no cache else: target_length = ( attention_mask.shape[-1] @@ -964,8 +657,6 @@ def _update_causal_mask( device=device, cache_position=cache_position, batch_size=input_tensor.shape[0], - config=self.config, - past_key_values=past_key_values, ) if ( @@ -977,12 +668,12 @@ def _update_causal_mask( # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) return causal_mask @staticmethod - # Copied from transformers.models.mistral.modeling_mistral.MistralModel._prepare_4d_causal_attention_mask_with_cache_position with Mistral->Qwen2 def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, @@ -991,8 +682,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position( device: torch.device, cache_position: torch.Tensor, batch_size: int, - config: Qwen2Config, - past_key_values: Cache, + **kwargs, ): """ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape @@ -1000,11 +690,13 @@ def _prepare_4d_causal_attention_mask_with_cache_position( Args: attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. sequence_length (`int`): The sequence length being processed. target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. dtype (`torch.dtype`): The dtype to use for the 4D attention mask. device (`torch.device`): @@ -1013,10 +705,6 @@ def _prepare_4d_causal_attention_mask_with_cache_position( Indices depicting the position of the input sequence tokens in the sequence. 
batch_size (`torch.Tensor`): Batch size. - config (`Qwen2Config`): - The model's configuration class - past_key_values (`Cache`): - The cache class that is being used currently to generate """ if attention_mask is not None and attention_mask.dim() == 4: # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. @@ -1026,30 +714,25 @@ def _prepare_4d_causal_attention_mask_with_cache_position( causal_mask = torch.full( (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device ) - diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) - if config.sliding_window is not None: - # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also - # the check is needed to verify is current checkpoint was trained with sliding window or not - if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: - sliding_attend_mask = torch.arange(target_length, device=device) <= ( - cache_position.reshape(-1, 1) - config.sliding_window - ) - diagonal_attend_mask.bitwise_or_(sliding_attend_mask) - causal_mask *= diagonal_attend_mask + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - if attention_mask.shape[-1] > target_length: - attention_mask = attention_mask[:, :target_length] mask_length = attention_mask.shape[-1] padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] padding_mask = padding_mask == 0 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( padding_mask, min_dtype ) + return causal_mask +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... + + class Qwen2ForCausalLM(Qwen2PreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} @@ -1088,7 +771,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, @@ -1097,7 +780,7 @@ def forward( return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, num_logits_to_keep: int = 0, - **loss_kwargs, + **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -1118,8 +801,8 @@ def forward( ```python >>> from transformers import AutoTokenizer, Qwen2ForCausalLM - >>> model = Qwen2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) - >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + >>> model = Qwen2ForCausalLM.from_pretrained("meta-qwen2/Qwen2-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-qwen2/Qwen2-2-7b-hf") >>> prompt = "Hey, are you conscious? Can you talk to me?" >>> inputs = tokenizer(prompt, return_tensors="pt") @@ -1129,7 +812,6 @@ def forward( >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] "Hey, are you conscious? 
Can you talk to me?\nI'm not conscious, but I can talk to you." ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -1148,6 +830,7 @@ def forward( output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, + **kwargs, ) hidden_states = outputs[0] @@ -1156,7 +839,7 @@ def forward( loss = None if labels is not None: - loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs) + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) if not return_dict: output = (logits,) + outputs[1:] @@ -1205,10 +888,10 @@ def set_input_embeddings(self, value): @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) def forward( self, - input_ids: torch.LongTensor = None, + input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, @@ -1260,27 +943,8 @@ def forward( loss = None if labels is not None: - labels = labels.to(logits.device) - if self.config.problem_type is None: - if self.num_labels == 1: - self.config.problem_type = "regression" - elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): - self.config.problem_type = "single_label_classification" - else: - self.config.problem_type = "multi_label_classification" - - if self.config.problem_type == "regression": - loss_fct = MSELoss() - if self.num_labels == 1: - loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) - else: - loss = loss_fct(pooled_logits, labels) - elif self.config.problem_type == "single_label_classification": - loss_fct = CrossEntropyLoss() - loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) - elif self.config.problem_type == "multi_label_classification": - loss_fct = BCEWithLogitsLoss() - loss = loss_fct(pooled_logits, labels) + loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) + if not return_dict: output = (pooled_logits,) + transformer_outputs[1:] return ((loss,) + output) if loss is not None else output @@ -1301,7 +965,6 @@ def forward( """, QWEN2_START_DOCSTRING, ) -# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Qwen2, LLAMA->QWEN2 class Qwen2ForTokenClassification(Qwen2PreTrainedModel): def __init__(self, config): super().__init__(config) @@ -1390,24 +1053,22 @@ def forward( """, QWEN2_START_DOCSTRING, ) -# Copied from transformers.models.mistral.modeling_mistral.MistralForQuestionAnswering with Mistral->Qwen2, MISTRAL->QWEN2 class Qwen2ForQuestionAnswering(Qwen2PreTrainedModel): - base_model_prefix = "model" + base_model_prefix = "transformer" - # Copied from models.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Qwen2 def __init__(self, config): super().__init__(config) - self.model = Qwen2Model(config) + self.transformer = Qwen2Model(config) self.qa_outputs = nn.Linear(config.hidden_size, 2) # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): - return 
self.model.embed_tokens + return self.transformer.embed_tokens def set_input_embeddings(self, value): - self.model.embed_tokens = value + self.transformer.embed_tokens = value @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) def forward( @@ -1436,7 +1097,7 @@ def forward( """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.model( + outputs = self.transformer( input_ids, attention_mask=attention_mask, position_ids=position_ids, diff --git a/src/transformers/models/qwen2/modular_qwen2.py b/src/transformers/models/qwen2/modular_qwen2.py new file mode 100644 index 00000000000000..718abd01090c2b --- /dev/null +++ b/src/transformers/models/qwen2/modular_qwen2.py @@ -0,0 +1,134 @@ +from typing import Callable, Optional, Tuple + +import torch +import torch.utils.checkpoint +from torch import nn + +from ...cache_utils import Cache +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS +from ...processing_utils import Unpack +from ...utils import logging +from ..llama.modeling_llama import ( + LlamaAttention, + LlamaDecoderLayer, + LlamaForCausalLM, + LlamaForQuestionAnswering, + LlamaForSequenceClassification, + LlamaForTokenClassification, + LlamaMLP, + LlamaModel, + apply_rotary_pos_emb, + eager_attention_forward, +) +from .configuration_qwen2 import Qwen2Config + + +logger = logging.get_logger(__name__) + + +class Qwen2MLP(LlamaMLP): + def __init__(self, config): + super().__init__(config) + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + + +class Qwen2Attention(LlamaAttention): + def __init__(self, config: Qwen2Config, layer_idx: int): + super().__init__(config, layer_idx) + self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=True) + self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True) + self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True) + self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + sliding_window = None + if ( + self.config.use_sliding_window + and 
getattr(self.config, "sliding_window", None) is not None + and self.layer_idx >= self.config.max_window_layers + ): + sliding_window = self.config.sliding_window + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=sliding_window, # main diff with Llama + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class Qwen2DecoderLayer(LlamaDecoderLayer): + def __init__(self, config: Qwen2Config, layer_idx: int): + super().__init__() + self.self_attn = Qwen2Attention(config=config, layer_idx=layer_idx) + self.mlp = Qwen2MLP(config) + if config.sliding_window and config._attn_implementation != "flash_attention_2": + logger.warning_once( + f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; " + "unexpected results may be encountered." + ) + + +class Qwen2Model(LlamaModel): + pass + + +class Qwen2ForCausalLM(LlamaForCausalLM): + pass + + +class Qwen2ForSequenceClassification(LlamaForSequenceClassification): + pass + + +class Qwen2ForTokenClassification(LlamaForTokenClassification): + pass + + +class Qwen2ForQuestionAnswering(LlamaForQuestionAnswering): + pass diff --git a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py index ce0e427048cf23..44a5b5ce315570 100644 --- a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py +++ b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py @@ -223,7 +223,6 @@ class Qwen2AudioFlashAttention2(Qwen2AudioAttention): flash attention and deal with padding tokens in case the input contains any of them. """ - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index a1e36b8ad7bc20..1ce41509a5c0d1 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -169,40 +169,18 @@ def extra_repr(self): class Qwen2MoeRotaryEmbedding(nn.Module): def __init__( self, - dim=None, - max_position_embeddings=2048, - base=10000, + config: Qwen2MoeConfig, device=None, - scaling_factor=1.0, - rope_type="default", - config: Optional[Qwen2MoeConfig] = None, ): super().__init__() - # TODO (joao): remove the `if` below, only used for BC self.rope_kwargs = {} - if config is None: - logger.warning_once( - "`Qwen2MoeRotaryEmbedding` can now be fully parameterized by passing the model config through the " - "`config` argument. 
All other arguments will be removed in v4.46" - ) - self.rope_kwargs = { - "rope_type": rope_type, - "factor": scaling_factor, - "dim": dim, - "base": base, - "max_position_embeddings": max_position_embeddings, - } - self.rope_type = rope_type - self.max_seq_len_cached = max_position_embeddings - self.original_max_seq_len = max_position_embeddings + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: - # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] @@ -318,7 +296,8 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2Attention with Qwen2->Qwen2Moe +# copied from transformers.models.qwen2.modeling_qwen2.Qwen2Attention with Qwen2->Qwen2Moe +# no longer copied after attention refactors class Qwen2MoeAttention(nn.Module): """ Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer @@ -419,7 +398,8 @@ def forward( return attn_output, attn_weights, past_key_value -# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2FlashAttention2 with Qwen2->Qwen2Moe +# NO LONGER EXIST Copied from transformers.models.qwen2.modeling_qwen2.Qwen2FlashAttention2 with Qwen2->Qwen2Moe +# TODO cyril: modular class Qwen2MoeFlashAttention2(Qwen2MoeAttention): """ Qwen2Moe flash attention module, following Qwen2Moe attention module. This module inherits from `Qwen2MoeAttention` @@ -429,7 +409,6 @@ class Qwen2MoeFlashAttention2(Qwen2MoeAttention): config.max_window_layers layers. """ - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -530,7 +509,8 @@ def forward( return attn_output, attn_weights, past_key_value -# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2SdpaAttention with Qwen2->Qwen2Moe +# NO LONGER EXIST Copied from transformers.models.qwen2.modeling_qwen2.Qwen2SdpaAttention with Qwen2->Qwen2Moe +# TODO cyril: modular class Qwen2MoeSdpaAttention(Qwen2MoeAttention): """ Qwen2Moe attention module using torch.nn.functional.scaled_dot_product_attention. 
This module inherits from @@ -1578,11 +1558,10 @@ def forward( class Qwen2MoeForQuestionAnswering(Qwen2MoePreTrainedModel): base_model_prefix = "model" - # Copied from models.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Qwen2Moe def __init__(self, config): super().__init__(config) - self.model = Qwen2MoeModel(config) self.qa_outputs = nn.Linear(config.hidden_size, 2) + self.model = Qwen2MoeModel(config) # diff with Llama: transformer->model # Initialize weights and apply final processing self.post_init() diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index dce0702b081942..10c9b1638548ce 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -460,6 +460,7 @@ def extra_repr(self): class Qwen2MLP(nn.Module): def __init__(self, config): super().__init__() + self.config = config self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) @@ -467,8 +468,9 @@ def __init__(self, config): self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) self.act_fn = ACT2FN[config.hidden_act] - def forward(self, hidden_state): - return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj # Copied from transformers.models.llama.modeling_llama.repeat_kv diff --git a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py index 2b3cf7eb0cb82e..74fc2085c36519 100644 --- a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +++ b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py @@ -77,7 +77,6 @@ def __init__(self, dim, base=10000, device=None): self.register_buffer("inv_freq", tensor=inv_freq, persistent=False) @torch.no_grad() - # Copied from transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding.forward with Gemma->RecurrentGemma def forward(self, x, position_ids, seq_len=None): # x: [bs, num_attention_heads, seq_len, head_size] self.inv_freq.to(x.device) @@ -185,7 +184,7 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None) + cos, sin = self.rotary_emb(value_states, position_ids) # Partial rotary embedding query_rot, query_pass = torch.chunk(query_states, int(1 / self.partial_rotary_factor), dim=-1) diff --git a/src/transformers/models/sew/modeling_sew.py b/src/transformers/models/sew/modeling_sew.py index 8638d93385843d..1959d21e1d5d94 100644 --- a/src/transformers/models/sew/modeling_sew.py +++ b/src/transformers/models/sew/modeling_sew.py @@ -563,7 +563,6 @@ class SEWFlashAttention2(SEWAttention): flash attention and deal with padding tokens in case the input contains any of them. 
""" - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/src/transformers/models/siglip/modeling_siglip.py b/src/transformers/models/siglip/modeling_siglip.py index a42bcd0e17461e..9a2dfe013716a7 100644 --- a/src/transformers/models/siglip/modeling_siglip.py +++ b/src/transformers/models/siglip/modeling_siglip.py @@ -438,7 +438,6 @@ class SiglipFlashAttention2(SiglipAttention): is_causal = False - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py index 0ce550697e79ab..88dc437cdcb91d 100755 --- a/src/transformers/models/stablelm/modeling_stablelm.py +++ b/src/transformers/models/stablelm/modeling_stablelm.py @@ -65,40 +65,18 @@ class StableLmRotaryEmbedding(nn.Module): def __init__( self, - dim=None, - max_position_embeddings=2048, - base=10000, + config: StableLmConfig, device=None, - scaling_factor=1.0, - rope_type="default", - config: Optional[StableLmConfig] = None, ): super().__init__() - # TODO (joao): remove the `if` below, only used for BC self.rope_kwargs = {} - if config is None: - logger.warning_once( - "`StableLmRotaryEmbedding` can now be fully parameterized by passing the model config through the " - "`config` argument. All other arguments will be removed in v4.46" - ) - self.rope_kwargs = { - "rope_type": rope_type, - "factor": scaling_factor, - "dim": dim, - "base": base, - "max_position_embeddings": max_position_embeddings, - } - self.rope_type = rope_type - self.max_seq_len_cached = max_position_embeddings - self.original_max_seq_len = max_position_embeddings + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: - # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] @@ -189,6 +167,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): class StableLmMLP(nn.Module): def __init__(self, config): super().__init__() + self.config = config self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) @@ -196,8 +175,9 @@ def __init__(self, config): self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) self.act_fn = ACT2FN[config.hidden_act] - def forward(self, hidden_state): - return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj class StableLmLayerNormPerHead(nn.Module): @@ -472,7 +452,6 @@ class StableLmFlashAttention2(StableLmAttention): flash attention and deal with padding 
tokens in case the input contains any of them. """ - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index 8047e23bb05bd8..3b4fdbcb81ccc4 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -24,8 +24,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import math -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch from torch import nn @@ -34,6 +33,7 @@ from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -41,115 +41,24 @@ TokenClassifierOutput, ) from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack from ...utils import ( + LossKwargs, add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, ) from .configuration_starcoder2 import Starcoder2Config -if is_flash_attn_2_available(): - from ...modeling_flash_attention_utils import _flash_attention_forward - - logger = logging.get_logger(__name__) _CHECKPOINT_FOR_DOC = "bigcode/starcoder2-7b" _CONFIG_FOR_DOC = "Starcoder2Config" -class Starcoder2RotaryEmbedding(nn.Module): - def __init__( - self, - dim=None, - max_position_embeddings=2048, - base=10000, - device=None, - scaling_factor=1.0, - rope_type="default", - config: Optional[Starcoder2Config] = None, - ): - super().__init__() - # TODO (joao): remove the `if` below, only used for BC - self.rope_kwargs = {} - if config is None: - logger.warning_once( - "`Starcoder2RotaryEmbedding` can now be fully parameterized by passing the model config through the " - "`config` argument. 
All other arguments will be removed in v4.46" - ) - self.rope_kwargs = { - "rope_type": rope_type, - "factor": scaling_factor, - "dim": dim, - "base": base, - "max_position_embeddings": max_position_embeddings, - } - self.rope_type = rope_type - self.max_seq_len_cached = max_position_embeddings - self.original_max_seq_len = max_position_embeddings - else: - # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings - - self.config = config - self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self.original_inv_freq = self.inv_freq - - def _dynamic_frequency_update(self, position_ids, device): - """ - dynamic RoPE layers should recompute `inv_freq` in the following situations: - 1 - growing beyond the cached sequence length (allow scaling) - 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) - """ - seq_len = torch.max(position_ids) + 1 - if seq_len > self.max_seq_len_cached: # growth - inv_freq, self.attention_scaling = self.rope_init_fn( - self.config, device, seq_len=seq_len, **self.rope_kwargs - ) - self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation - self.max_seq_len_cached = seq_len - - if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset - self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) - self.max_seq_len_cached = self.original_max_seq_len - - @torch.no_grad() - def forward(self, x, position_ids): - if "dynamic" in self.rope_type: - self._dynamic_frequency_update(position_ids, device=x.device) - - # Core RoPE block - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) - position_ids_expanded = position_ids[:, None, :].float() - # Force float32 (see https://github.com/huggingface/transformers/pull/29285) - device_type = x.device.type - device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() - sin = emb.sin() - - # Advanced RoPE types (e.g. 
yarn) apply a post-processing scaling factor, equivalent to scaling attention - cos = cos * self.attention_scaling - sin = sin * self.attention_scaling - - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) - - class Starcoder2MLP(nn.Module): def __init__(self, config: Starcoder2Config): super().__init__() @@ -213,309 +122,111 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + class Starcoder2Attention(nn.Module): - """ - Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer - and "Generating Long Sequences with Sparse Transformers". - """ + """Multi-headed attention from 'Attention Is All You Need' paper""" def __init__(self, config: Starcoder2Config, layer_idx: Optional[int] = None): super().__init__() self.config = config self.layer_idx = layer_idx - if layer_idx is None: - logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " - "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class." - ) - - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.rope_theta = config.rope_theta - self.use_bias = config.use_bias - self.is_causal = True + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 self.attention_dropout = config.attention_dropout + self.is_causal = True + self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.use_bias) + self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias) + self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias) + self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.use_bias) self.residual_dropout = config.residual_dropout - if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." 
- ) - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=self.use_bias) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=self.use_bias) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=self.use_bias) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=self.use_bias) - def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights += causal_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - attn_output = self.o_proj(attn_output) - attn_output = nn.functional.dropout(attn_output, p=self.residual_dropout, training=self.training) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class Starcoder2FlashAttention2(Starcoder2Attention): - """ - Starcoder2 flash attention module. This module inherits from `Starcoder2Attention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. 
- """ + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - ): - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - dropout_rate = 0.0 if not self.training else self.attention_dropout - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in float16 just to be sure everything works as expected. - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. 
This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - # Reshape to the expected shape for Flash Attention - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - attn_output = _flash_attention_forward( + attn_output, attn_weights = attention_interface( + self, query_states, key_states, value_states, attention_mask, - q_len, - position_ids=position_ids, - dropout=dropout_rate, - sliding_window=getattr(self.config, "sliding_window", None), - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - ) - - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() - attn_output = self.o_proj(attn_output) - attn_output = nn.functional.dropout(attn_output, p=self.residual_dropout, training=self.training) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class Starcoder2SdpaAttention(Starcoder2Attention): - """ - Starcoder2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from - `Starcoder2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to - SDPA API. - """ - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "Starcoder2Model is using Starcoder2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
- ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - ) - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - causal_mask = attention_mask - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and attention_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - # # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. 
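The hunks above collapse the separate eager/SDPA/flash attention classes into a single `Starcoder2Attention` whose `forward` dispatches through `ALL_ATTENTION_FUNCTIONS` and falls back to `eager_attention_forward`. As a minimal sketch of why the eager path agrees with SDPA, including the grouped-query `repeat_kv` expansion (plain PyTorch, toy shapes, helpers re-implemented locally; not the transformers code itself):

```python
# Minimal sketch, assuming PyTorch >= 2.0; toy shapes, not the transformers implementation.
import torch
import torch.nn.functional as F

def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    # (batch, num_kv_heads, seq, head_dim) -> (batch, num_kv_heads * n_rep, seq, head_dim)
    b, kv, s, d = x.shape
    if n_rep == 1:
        return x
    return x[:, :, None, :, :].expand(b, kv, n_rep, s, d).reshape(b, kv * n_rep, s, d)

def eager_attention(q, k, v, mask, scaling):
    # same algebra as the eager path in the diff: scaled QK^T, additive causal mask, softmax, PV
    k, v = repeat_kv(k, q.shape[1] // k.shape[1]), repeat_kv(v, q.shape[1] // v.shape[1])
    w = torch.matmul(q, k.transpose(2, 3)) * scaling
    if mask is not None:
        w = w + mask[:, :, :, : k.shape[-2]]
    w = F.softmax(w, dim=-1, dtype=torch.float32).to(q.dtype)
    return torch.matmul(w, v)

b, n_heads, n_kv_heads, seq, head_dim = 2, 8, 2, 16, 32
q = torch.randn(b, n_heads, seq, head_dim)
k = torch.randn(b, n_kv_heads, seq, head_dim)
v = torch.randn(b, n_kv_heads, seq, head_dim)
causal = torch.triu(torch.full((seq, seq), float("-inf")), diagonal=1)[None, None]

out_eager = eager_attention(q, k, v, causal, scaling=head_dim**-0.5)
out_sdpa = F.scaled_dot_product_attention(
    q, repeat_kv(k, n_heads // n_kv_heads), repeat_kv(v, n_heads // n_kv_heads), attn_mask=causal
)
print(torch.allclose(out_eager, out_sdpa, atol=1e-5))  # expected: True
```

The sketch only covers the dense causal case; the `sliding_window` argument forwarded in the new `forward` is not modeled here.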
- is_causal = True if causal_mask is None and q_len > 1 else False - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=causal_mask, - dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=is_causal, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=getattr(self.config, "sliding_window", None), # diff with Llama + **kwargs, ) - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - + attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) - # The difference with Mistral is that here it uses dropout - attn_output = nn.functional.dropout(attn_output, p=self.residual_dropout, training=self.training) - - return attn_output, None, past_key_value - + attn_output = nn.functional.dropout( + attn_output, p=self.residual_dropout, training=self.training + ) # diff with Llama -STARCODER2_ATTENTION_CLASSES = { - "eager": Starcoder2Attention, - "flash_attention_2": Starcoder2FlashAttention2, - "sdpa": Starcoder2SdpaAttention, -} + return attn_output, attn_weights class Starcoder2DecoderLayer(nn.Module): def __init__(self, config: Starcoder2Config, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size - - self.self_attn = STARCODER2_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) - + self.self_attn = Starcoder2Attention(config=config, layer_idx=layer_idx) self.mlp = Starcoder2MLP(config) - self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon) self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon) @@ -524,41 +235,19 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - **kwargs, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): attention mask of size - `(batch, sequence_length)` where padding elements are indicated by 0. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence. - position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): - Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, - with `head_dim` being the embedding dimension of each attention head. 
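The decoder-layer changes nearby drop `present_key_value` from the layer outputs (the `Cache` object passed in is updated in place) while keeping the usual pre-norm residual structure. A standalone sketch of that structure, with stand-in attention and MLP modules rather than the actual Starcoder2 classes (which also handle RoPE, caching, and masks):

```python
# Standalone sketch of a pre-norm residual layer; ToyLayer is a hypothetical stand-in.
import torch
from torch import nn

class ToyLayer(nn.Module):
    def __init__(self, hidden_size: int = 64, intermediate_size: int = 256, num_heads: int = 4):
        super().__init__()
        self.input_layernorm = nn.LayerNorm(hidden_size)
        self.post_attention_layernorm = nn.LayerNorm(hidden_size)
        self.self_attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, intermediate_size), nn.GELU(), nn.Linear(intermediate_size, hidden_size)
        )

    def forward(self, hidden_states: torch.Tensor, output_attentions: bool = False):
        # Self attention with residual connection
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states, attn_weights = self.self_attn(
            hidden_states, hidden_states, hidden_states, need_weights=output_attentions
        )
        hidden_states = residual + hidden_states

        # MLP with residual connection
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = residual + self.mlp(hidden_states)

        # mirror the refactored return shape: hidden states, plus weights only when requested
        outputs = (hidden_states,)
        if output_attentions:
            outputs += (attn_weights,)
        return outputs

x = torch.randn(2, 16, 64)
print(ToyLayer()(x)[0].shape)  # torch.Size([2, 16, 64])
```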
- kwargs (`dict`, *optional*): - Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code - into the model - """ - residual = hidden_states hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, @@ -567,6 +256,7 @@ def forward( use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, + **kwargs, ) hidden_states = residual + hidden_states @@ -577,16 +267,77 @@ def forward( hidden_states = residual + hidden_states outputs = (hidden_states,) - if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) - return outputs +class Starcoder2RotaryEmbedding(nn.Module): + def __init__( + self, + config: Starcoder2Config, + device=None, + ): + super().__init__() + self.rope_kwargs = {} + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. 
yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + STARCODER2_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads @@ -613,7 +364,7 @@ class Starcoder2PreTrainedModel(PreTrainedModel): base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["Starcoder2DecoderLayer"] - _skip_keys_device_placement = "past_key_values" + _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True _supports_sdpa = True _supports_cache_class = True @@ -653,7 +404,7 @@ def _init_weights(self, module): Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. - If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see `past_key_values`). If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] @@ -728,12 +479,11 @@ def __init__(self, config: Starcoder2Config): self.layers = nn.ModuleList( [Starcoder2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) - self._attn_implementation = config._attn_implementation self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon) self.rotary_emb = Starcoder2RotaryEmbedding(config=config) - self.gradient_checkpointing = False self.embedding_dropout = config.embedding_dropout + # Initialize weights and apply final processing self.post_init() @@ -749,54 +499,43 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
- ) - use_cache = False - - # kept for BC (non `Cache` `past_key_values` inputs) - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True - if past_key_values is None: - past_key_values = DynamicCache() - else: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " - "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class " - "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" - ) + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) + if position_ids is None: position_ids = cache_position.unsqueeze(0) @@ -805,7 +544,9 @@ def forward( ) hidden_states = inputs_embeds - hidden_states = nn.functional.dropout(hidden_states, p=self.embedding_dropout, training=self.training) + hidden_states = nn.functional.dropout( + hidden_states, p=self.embedding_dropout, training=self.training + ) # main diff with Llama # create position embeddings to be shared across the decoder layers position_embeddings = self.rotary_emb(hidden_states, position_ids) @@ -813,41 +554,25 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_decoder_cache = None - for decoder_layer in self.layers: + for decoder_layer in self.layers[: self.config.num_hidden_layers]: if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -857,18 +582,13 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None - if return_legacy_cache: - next_cache = next_cache.to_legacy_cache() - - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( + output 
= BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, ) + return output if return_dict else output.to_tuple() def _update_causal_mask( self, @@ -879,6 +599,14 @@ def _update_causal_mask( output_attentions: bool, ): if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and past_key_values is not None: + is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] + if is_padding_right: + raise ValueError( + "You are attempting to perform batched generation with padding_side='right'" + " this may lead to unexpected behaviour for Flash Attention version of Starcoder2. Make sure to " + " call `tokenizer.padding_side = 'left'` before tokenizing the input. " + ) if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None @@ -1013,6 +741,9 @@ def _prepare_4d_causal_attention_mask_with_cache_position( return causal_mask +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... + + class Starcoder2ForCausalLM(Starcoder2PreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} @@ -1051,7 +782,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, @@ -1060,7 +791,7 @@ def forward( return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, num_logits_to_keep: int = 0, - **loss_kwargs, + **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -1081,8 +812,8 @@ def forward( ```python >>> from transformers import AutoTokenizer, Starcoder2ForCausalLM - >>> model = Starcoder2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) - >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + >>> model = Starcoder2ForCausalLM.from_pretrained("meta-starcoder2/Starcoder2-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-starcoder2/Starcoder2-2-7b-hf") >>> prompt = "Hey, are you conscious? Can you talk to me?" >>> inputs = tokenizer(prompt, return_tensors="pt") @@ -1092,7 +823,6 @@ def forward( >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -1111,6 +841,7 @@ def forward( output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, + **kwargs, ) hidden_states = outputs[0] @@ -1119,7 +850,7 @@ def forward( loss = None if labels is not None: - loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs) + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) if not return_dict: output = (logits,) + outputs[1:] diff --git a/src/transformers/models/starcoder2/modular_starcoder2.py b/src/transformers/models/starcoder2/modular_starcoder2.py index 013c8e472b325d..32d64cd167ba50 100644 --- a/src/transformers/models/starcoder2/modular_starcoder2.py +++ b/src/transformers/models/starcoder2/modular_starcoder2.py @@ -19,8 +19,7 @@ # limitations under the License. """PyTorch Starcoder2 model.""" -import math -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch import torch.utils.checkpoint @@ -28,40 +27,32 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( BaseModelOutputWithPast, ) -from ...utils import ( - add_start_docstrings_to_model_forward, - is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, - logging, -) -from ..llama.modeling_llama import ( - LlamaForSequenceClassification, - LlamaForTokenClassification, - LlamaRotaryEmbedding, +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS +from ...processing_utils import Unpack +from ...utils import add_start_docstrings_to_model_forward, logging +from ..mistral.modeling_mistral import ( + MistralAttention, + MistralDecoderLayer, + MistralForCausalLM, + MistralForSequenceClassification, + MistralForTokenClassification, + MistralModel, apply_rotary_pos_emb, - repeat_kv, + eager_attention_forward, ) -from ..qwen2.modeling_qwen2 import Qwen2DecoderLayer, Qwen2ForCausalLM, Qwen2Model, Qwen2PreTrainedModel from .configuration_starcoder2 import Starcoder2Config -if is_flash_attn_2_available(): - from ...modeling_flash_attention_utils import _flash_attention_forward - - logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "Starcoder2Config" _CHECKPOINT_FOR_DOC = "bigcode/starcoder2-7b" -class Starcoder2RotaryEmbedding(LlamaRotaryEmbedding): - pass - - class Starcoder2MLP(nn.Module): def __init__(self, config: Starcoder2Config): super().__init__() @@ -79,332 +70,90 @@ def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.Fl return hidden_states -class Starcoder2Attention(nn.Module): - """ - Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer - and "Generating Long Sequences with Sparse Transformers". - """ - +class Starcoder2Attention(MistralAttention): def __init__(self, config: Starcoder2Config, layer_idx: Optional[int] = None): super().__init__() - self.config = config - self.layer_idx = layer_idx - if layer_idx is None: - logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " - "lead to errors during the forward call if caching is used. 
Please make sure to provide a `layer_idx` " - "when creating this class." - ) - - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.rope_theta = config.rope_theta - self.use_bias = config.use_bias - self.is_causal = True - self.attention_dropout = config.attention_dropout self.residual_dropout = config.residual_dropout - - if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." - ) - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=self.use_bias) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=self.use_bias) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=self.use_bias) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=self.use_bias) + self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.use_bias) + self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias) + self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias) + self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.use_bias) def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights += causal_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, 
dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - attn_output = self.o_proj(attn_output) - attn_output = nn.functional.dropout(attn_output, p=self.residual_dropout, training=self.training) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) - -class Starcoder2FlashAttention2(Starcoder2Attention): - """ - Starcoder2 flash attention module. This module inherits from `Starcoder2Attention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
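Both the removed and the replacement attention forwards rotate the query/key states with `apply_rotary_pos_emb`, using the `cos`/`sin` tensors produced by the rotary embedding module. A minimal, self-contained sketch of that rotation for the default (unscaled) RoPE variant, with locally defined helpers rather than the transformers ones:

```python
# Illustrative RoPE sketch (default rope_type, no scaling) -- not the transformers implementation.
import torch

def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def rope_cos_sin(position_ids, head_dim, base=10000.0):
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))  # (head_dim/2,)
    freqs = position_ids[:, :, None].float() * inv_freq[None, None, :]            # (bsz, seq, head_dim/2)
    emb = torch.cat((freqs, freqs), dim=-1)                                        # (bsz, seq, head_dim)
    return emb.cos(), emb.sin()

def apply_rope(q, k, cos, sin, unsqueeze_dim=1):
    cos, sin = cos.unsqueeze(unsqueeze_dim), sin.unsqueeze(unsqueeze_dim)
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin

bsz, heads, seq, head_dim = 1, 4, 8, 32
q = torch.randn(bsz, heads, seq, head_dim)
k = torch.randn(bsz, heads, seq, head_dim)
position_ids = torch.arange(seq)[None, :]
cos, sin = rope_cos_sin(position_ids, head_dim)
q_rot, k_rot = apply_rope(q, k, cos, sin)
print(q_rot.shape, k_rot.shape)  # torch.Size([1, 4, 8, 32]) twice
```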
- self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - ): - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - dropout_rate = 0.0 if not self.training else self.attention_dropout - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in float16 just to be sure everything works as expected. - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." 
- ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - # Reshape to the expected shape for Flash Attention - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - attn_output = _flash_attention_forward( + attn_output, attn_weights = attention_interface( + self, query_states, key_states, value_states, attention_mask, - q_len, - position_ids=position_ids, - dropout=dropout_rate, - sliding_window=getattr(self.config, "sliding_window", None), - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - ) - - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() - attn_output = self.o_proj(attn_output) - attn_output = nn.functional.dropout(attn_output, p=self.residual_dropout, training=self.training) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class Starcoder2SdpaAttention(Starcoder2Attention): - """ - Starcoder2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from - `Starcoder2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to - SDPA API. - """ - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "Starcoder2Model is using Starcoder2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
- ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - ) - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - causal_mask = attention_mask - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and attention_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - # # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. 
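The replacement code selects the attention backend at runtime by looking up `self.config._attn_implementation` in `ALL_ATTENTION_FUNCTIONS`, with `eager_attention_forward` as the fallback (and a warning when SDPA is asked for `output_attentions`). A toy sketch of that registry-dispatch idea; the registry and function names below are hypothetical stand-ins, not the transformers API:

```python
# Toy registry dispatch; hypothetical names (assumes PyTorch >= 2.1 for the `scale` kwarg).
from typing import Callable, Dict, Optional
import torch
import torch.nn.functional as F

def eager_attention(q, k, v, mask: Optional[torch.Tensor], scaling: float, **kwargs):
    w = torch.matmul(q, k.transpose(2, 3)) * scaling
    if mask is not None:
        w = w + mask[:, :, :, : k.shape[-2]]
    w = F.softmax(w, dim=-1, dtype=torch.float32).to(q.dtype)
    return torch.matmul(w, v), w

def sdpa_attention(q, k, v, mask, scaling, **kwargs):
    return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, scale=scaling), None

TOY_ATTENTION_FUNCTIONS: Dict[str, Callable] = {"sdpa": sdpa_attention}

def dispatch(impl: str, q, k, v, mask=None, output_attentions: bool = False):
    scaling = q.shape[-1] ** -0.5
    attention_interface: Callable = eager_attention
    if impl != "eager":
        if impl == "sdpa" and output_attentions:
            print("sdpa cannot return attention weights, falling back to eager")  # warning in the real code
        else:
            attention_interface = TOY_ATTENTION_FUNCTIONS[impl]
    return attention_interface(q, k, v, mask, scaling)

q = k = v = torch.randn(1, 2, 4, 8)
out_sdpa, _ = dispatch("sdpa", q, k, v)
out_eager, w = dispatch("sdpa", q, k, v, output_attentions=True)  # routed back to eager
print(torch.allclose(out_sdpa, out_eager, atol=1e-5), w.shape)
```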
- is_causal = True if causal_mask is None and q_len > 1 else False - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=causal_mask, - dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=is_causal, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=getattr(self.config, "sliding_window", None), # diff with Llama + **kwargs, ) - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - + attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) - # The difference with Mistral is that here it uses dropout - attn_output = nn.functional.dropout(attn_output, p=self.residual_dropout, training=self.training) - - return attn_output, None, past_key_value - + attn_output = nn.functional.dropout( + attn_output, p=self.residual_dropout, training=self.training + ) # diff with Llama -STARCODER2_ATTENTION_CLASSES = { - "eager": Starcoder2Attention, - "flash_attention_2": Starcoder2FlashAttention2, - "sdpa": Starcoder2SdpaAttention, -} + return attn_output, attn_weights -class Starcoder2DecoderLayer(Qwen2DecoderLayer, nn.Module): +class Starcoder2DecoderLayer(MistralDecoderLayer): def __init__(self, config: Starcoder2Config, layer_idx: int): - nn.Module.__init__(self) - self.hidden_size = config.hidden_size - - self.self_attn = STARCODER2_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) - + super().__init__(self) + self.self_attn = Starcoder2Attention(config=config, layer_idx=layer_idx) self.mlp = Starcoder2MLP(config) - self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon) self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon) -class Starcoder2PreTrainedModel(Qwen2PreTrainedModel): - pass - - STARCODER2_INPUTS_DOCSTRING = None # will be automatically redefined -class Starcoder2Model(Qwen2Model): - """ - Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`Starcoder2DecoderLayer`] - - Args: - config: Starcoder2Config - """ - +class Starcoder2Model(MistralModel): def __init__(self, config: Starcoder2Config): super().__init__(config) - self.embedding_dropout = config.embedding_dropout + self.layers = nn.ModuleList( + [Starcoder2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon) + self.embedding_dropout = config.embedding_dropout @add_start_docstrings_to_model_forward(STARCODER2_INPUTS_DOCSTRING) def forward( @@ -412,54 +161,43 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - # kept for BC (non `Cache` `past_key_values` inputs) - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True - if past_key_values is None: - past_key_values = DynamicCache() - else: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " - "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class " - "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" - ) + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." 
+ ) + use_cache = False if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) + if position_ids is None: position_ids = cache_position.unsqueeze(0) @@ -468,7 +206,9 @@ def forward( ) hidden_states = inputs_embeds - hidden_states = nn.functional.dropout(hidden_states, p=self.embedding_dropout, training=self.training) + hidden_states = nn.functional.dropout( + hidden_states, p=self.embedding_dropout, training=self.training + ) # main diff with Llama # create position embeddings to be shared across the decoder layers position_embeddings = self.rotary_emb(hidden_states, position_ids) @@ -476,41 +216,25 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_decoder_cache = None - for decoder_layer in self.layers: + for decoder_layer in self.layers[: self.config.num_hidden_layers]: if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -520,36 +244,31 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None - if return_legacy_cache: - next_cache = next_cache.to_legacy_cache() - - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( + output = BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, ) + return output if return_dict else output.to_tuple() -class Starcoder2ForCausalLM(Qwen2ForCausalLM): +class Starcoder2ForCausalLM(MistralForCausalLM): pass -class Starcoder2ForSequenceClassification(LlamaForSequenceClassification): +class Starcoder2ForSequenceClassification(MistralForSequenceClassification): pass -class Starcoder2ForTokenClassification(LlamaForTokenClassification): +class Starcoder2ForTokenClassification(MistralForTokenClassification): pass __all__ = [ "Starcoder2ForCausalLM", "Starcoder2Model", - "Starcoder2PreTrainedModel", + "Starcoder2PreTrainedModel", # noqa: F822 
"Starcoder2ForSequenceClassification", "Starcoder2ForTokenClassification", ] diff --git a/src/transformers/models/unispeech/modeling_unispeech.py b/src/transformers/models/unispeech/modeling_unispeech.py index 6ce5e77706d358..d1496432279527 100755 --- a/src/transformers/models/unispeech/modeling_unispeech.py +++ b/src/transformers/models/unispeech/modeling_unispeech.py @@ -595,7 +595,6 @@ class UniSpeechFlashAttention2(UniSpeechAttention): flash attention and deal with padding tokens in case the input contains any of them. """ - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py index 52d82ea739426b..49551b73577ad7 100755 --- a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py +++ b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py @@ -612,7 +612,6 @@ class UniSpeechSatFlashAttention2(UniSpeechSatAttention): flash attention and deal with padding tokens in case the input contains any of them. """ - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index bf1bb7746ce802..ca743e1eaef3af 100755 --- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -659,7 +659,6 @@ class Wav2Vec2FlashAttention2(Wav2Vec2Attention): flash attention and deal with padding tokens in case the input contains any of them. """ - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index ce3df3e16707e5..fb01823a29c017 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -354,7 +354,6 @@ class WhisperFlashAttention2(WhisperAttention): flash attention and deal with padding tokens in case the input contains any of them. """ - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/src/transformers/models/zamba/modeling_zamba.py b/src/transformers/models/zamba/modeling_zamba.py index dee7f898fcf93a..3b7348eadd4785 100644 --- a/src/transformers/models/zamba/modeling_zamba.py +++ b/src/transformers/models/zamba/modeling_zamba.py @@ -312,7 +312,6 @@ class ZambaFlashAttention2(ZambaAttention): flash attention and deal with padding tokens in case the input contains any of them. 
""" - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -774,6 +773,7 @@ def forward(self, hidden_states, cache_params: HybridMambaAttentionDynamicCache class ZambaMLP(nn.Module): def __init__(self, config): super().__init__() + self.config = config self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) @@ -781,8 +781,9 @@ def __init__(self, config): self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) self.act_fn = ACT2FN[config.hidden_act] - def forward(self, hidden_state): - return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj class ZambaAttentionDecoderLayer(nn.Module): diff --git a/tests/models/encoder_decoder/test_modeling_encoder_decoder.py b/tests/models/encoder_decoder/test_modeling_encoder_decoder.py index 64ebedcb45984b..1c4051f2e2645c 100644 --- a/tests/models/encoder_decoder/test_modeling_encoder_decoder.py +++ b/tests/models/encoder_decoder/test_modeling_encoder_decoder.py @@ -733,15 +733,6 @@ def test_sdpa_can_dispatch_composite_models(self): if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: raise ValueError("The eager model should not have SDPA attention layers") - has_sdpa = False - for name, submodule in model_sdpa.named_modules(): - class_name = submodule.__class__.__name__ - if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: - has_sdpa = True - break - if not has_sdpa: - raise ValueError("The SDPA model should have SDPA attention layers") - @require_torch class BertEncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase): diff --git a/tests/models/falcon/test_modeling_falcon.py b/tests/models/falcon/test_modeling_falcon.py index 129bd346a10d8f..3ad46a92bc0938 100644 --- a/tests/models/falcon/test_modeling_falcon.py +++ b/tests/models/falcon/test_modeling_falcon.py @@ -453,11 +453,9 @@ def test_model_rope_scaling_from_config(self, scaling_type): # The output should be different for long inputs self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5)) + # Copied from tests.models.gpt_neox.test_modeling_gpt_neox.GPTNeoXModelTest.test_model_rope_scaling with GPTNeoX->Falcon def test_model_rope_scaling(self): config, _ = self.model_tester.prepare_config_and_inputs_for_common() - hidden_size = config.hidden_size - num_heads = config.num_attention_heads - head_dim = hidden_size // num_heads scaling_factor = 10 short_input_length = 10 long_input_length = int(config.max_position_embeddings * 1.5) @@ -470,11 +468,7 @@ def test_model_rope_scaling(self): position_ids_long = position_ids_long.unsqueeze(0) # Sanity check original RoPE - original_rope = FalconRotaryEmbedding( - head_dim, - max_position_embeddings=config.max_position_embeddings, - base=config.rope_theta, - ).to(torch_device) + original_rope = FalconRotaryEmbedding(config).to(torch_device) original_cos_short, original_sin_short = original_rope(x, position_ids_short) original_cos_long, original_sin_long = original_rope(x, position_ids_long) torch.testing.assert_close(original_cos_short, original_cos_long[:, :short_input_length, :]) @@ -482,13 +476,8 @@ def test_model_rope_scaling(self): # Sanity check linear RoPE scaling # New position "x" 
should match original position with index "x/scaling_factor" - linear_scaling_rope = FalconRotaryEmbedding( - head_dim, - max_position_embeddings=config.max_position_embeddings, - base=config.rope_theta, - scaling_factor=scaling_factor, - rope_type="linear", - ).to(torch_device) + config.rope_scaling = {"type": "linear", "factor": scaling_factor} + linear_scaling_rope = FalconRotaryEmbedding(config).to(torch_device) linear_cos_short, linear_sin_short = linear_scaling_rope(x, position_ids_short) linear_cos_long, linear_sin_long = linear_scaling_rope(x, position_ids_long) torch.testing.assert_close(linear_cos_short, linear_cos_long[:, :short_input_length, :]) @@ -501,13 +490,8 @@ def test_model_rope_scaling(self): # Sanity check Dynamic NTK RoPE scaling # Scaling should only be observed after a long input is fed. We can observe that the frequencies increase # with scaling_factor (or that `inv_freq` decreases) - ntk_scaling_rope = FalconRotaryEmbedding( - head_dim, - max_position_embeddings=config.max_position_embeddings, - base=config.rope_theta, - scaling_factor=scaling_factor, - rope_type="dynamic", - ).to(torch_device) + config.rope_scaling = {"type": "dynamic", "factor": scaling_factor} + ntk_scaling_rope = FalconRotaryEmbedding(config).to(torch_device) ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, position_ids_short) ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, position_ids_long) torch.testing.assert_close(ntk_cos_short, original_cos_short) diff --git a/tests/models/gpt2/test_modeling_gpt2.py b/tests/models/gpt2/test_modeling_gpt2.py index 012444b472c0fc..88ccdc8ee45a2d 100644 --- a/tests/models/gpt2/test_modeling_gpt2.py +++ b/tests/models/gpt2/test_modeling_gpt2.py @@ -507,7 +507,7 @@ class GPT2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin else {} ) all_parallelizable_model_classes = (GPT2LMHeadModel, GPT2DoubleHeadsModel) if is_torch_available() else () - fx_compatible = True + fx_compatible = False # Broken by attention refactor cc @Cyrilvallez test_missing_keys = False test_model_parallel = True diff --git a/tests/models/gpt_neox/test_modeling_gpt_neox.py b/tests/models/gpt_neox/test_modeling_gpt_neox.py index ca9fbb225c6d87..6d5e081d50b152 100644 --- a/tests/models/gpt_neox/test_modeling_gpt_neox.py +++ b/tests/models/gpt_neox/test_modeling_gpt_neox.py @@ -366,12 +366,8 @@ def test_model_rope_scaling_from_config(self, scaling_type): # The output should be different for long inputs self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5)) - # Copied from tests.models.falcon.test_modeling_falcon.FalconModelTest.test_model_rope_scaling with Falcon->GPTNeoX, rope_theta->rotary_emb_base def test_model_rope_scaling(self): config, _ = self.model_tester.prepare_config_and_inputs_for_common() - hidden_size = config.hidden_size - num_heads = config.num_attention_heads - head_dim = hidden_size // num_heads scaling_factor = 10 short_input_length = 10 long_input_length = int(config.max_position_embeddings * 1.5) @@ -384,11 +380,7 @@ def test_model_rope_scaling(self): position_ids_long = position_ids_long.unsqueeze(0) # Sanity check original RoPE - original_rope = GPTNeoXRotaryEmbedding( - head_dim, - max_position_embeddings=config.max_position_embeddings, - base=config.rotary_emb_base, - ).to(torch_device) + original_rope = GPTNeoXRotaryEmbedding(config).to(torch_device) original_cos_short, original_sin_short = original_rope(x, position_ids_short) original_cos_long, original_sin_long = original_rope(x, position_ids_long) 
torch.testing.assert_close(original_cos_short, original_cos_long[:, :short_input_length, :]) @@ -396,13 +388,8 @@ def test_model_rope_scaling(self): # Sanity check linear RoPE scaling # New position "x" should match original position with index "x/scaling_factor" - linear_scaling_rope = GPTNeoXRotaryEmbedding( - head_dim, - max_position_embeddings=config.max_position_embeddings, - base=config.rotary_emb_base, - scaling_factor=scaling_factor, - rope_type="linear", - ).to(torch_device) + config.rope_scaling = {"type": "linear", "factor": scaling_factor} + linear_scaling_rope = GPTNeoXRotaryEmbedding(config).to(torch_device) linear_cos_short, linear_sin_short = linear_scaling_rope(x, position_ids_short) linear_cos_long, linear_sin_long = linear_scaling_rope(x, position_ids_long) torch.testing.assert_close(linear_cos_short, linear_cos_long[:, :short_input_length, :]) @@ -415,13 +402,8 @@ def test_model_rope_scaling(self): # Sanity check Dynamic NTK RoPE scaling # Scaling should only be observed after a long input is fed. We can observe that the frequencies increase # with scaling_factor (or that `inv_freq` decreases) - ntk_scaling_rope = GPTNeoXRotaryEmbedding( - head_dim, - max_position_embeddings=config.max_position_embeddings, - base=config.rotary_emb_base, - scaling_factor=scaling_factor, - rope_type="dynamic", - ).to(torch_device) + config.rope_scaling = {"type": "dynamic", "factor": scaling_factor} + ntk_scaling_rope = GPTNeoXRotaryEmbedding(config).to(torch_device) ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, position_ids_short) ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, position_ids_long) torch.testing.assert_close(ntk_cos_short, original_cos_short) diff --git a/tests/models/idefics2/test_modeling_idefics2.py b/tests/models/idefics2/test_modeling_idefics2.py index ae8c91f29d4d46..83e125c07c15bc 100644 --- a/tests/models/idefics2/test_modeling_idefics2.py +++ b/tests/models/idefics2/test_modeling_idefics2.py @@ -362,15 +362,6 @@ def test_sdpa_can_dispatch_composite_models(self): if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: raise ValueError("The eager model should not have SDPA attention layers") - has_sdpa = False - for name, submodule in model_sdpa.named_modules(): - class_name = submodule.__class__.__name__ - if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: - has_sdpa = True - break - if not has_sdpa and model_sdpa.config.model_type != "falcon": - raise ValueError("The SDPA model should have SDPA attention layers") - @require_torch class Idefics2ForConditionalGenerationModelTest(GenerationTesterMixin, ModelTesterMixin, unittest.TestCase): diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py index 0790de4e133b97..78e42e6ba71f2f 100644 --- a/tests/models/llama/test_modeling_llama.py +++ b/tests/models/llama/test_modeling_llama.py @@ -308,7 +308,7 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi ) test_headmasking = False test_pruning = False - fx_compatible = True + fx_compatible = False # Broken by attention refactor cc @Cyrilvallez # Need to use `0.8` instead of `0.9` for `test_cpu_offload` # This is because we are hitting edge cases with the causal_mask buffer @@ -571,10 +571,6 @@ def test_use_flash_attention_2_true(self): if not has_flash: raise ValueError("The flash model should have flash attention layers") - @unittest.skip("Broken by the loss update will fix soon @ArthurZucker") - def test_torch_fx_output_loss(self, *args, **kwargs): - pass 
- @require_torch_gpu class LlamaIntegrationTest(unittest.TestCase): diff --git a/tests/models/mistral/test_modeling_mistral.py b/tests/models/mistral/test_modeling_mistral.py index c5ea050edf92ef..d9e6b9d7bfe7c0 100644 --- a/tests/models/mistral/test_modeling_mistral.py +++ b/tests/models/mistral/test_modeling_mistral.py @@ -316,7 +316,7 @@ class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi ) test_headmasking = False test_pruning = False - fx_compatible = True + fx_compatible = False # Broken by attention refactor cc @Cyrilvallez # TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146 def is_pipeline_test_to_skip( diff --git a/tests/models/mixtral/test_modeling_mixtral.py b/tests/models/mixtral/test_modeling_mixtral.py index 931bb1f17beccf..9abbf444d0b0b4 100644 --- a/tests/models/mixtral/test_modeling_mixtral.py +++ b/tests/models/mixtral/test_modeling_mixtral.py @@ -314,7 +314,7 @@ class MixtralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi ) test_headmasking = False test_pruning = False - fx_compatible = True + fx_compatible = False # Broken by attention refactor cc @Cyrilvallez # TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146 def is_pipeline_test_to_skip( diff --git a/tests/models/persimmon/test_modeling_persimmon.py b/tests/models/persimmon/test_modeling_persimmon.py index 54ee49b65343ee..e783cea95a63b3 100644 --- a/tests/models/persimmon/test_modeling_persimmon.py +++ b/tests/models/persimmon/test_modeling_persimmon.py @@ -417,12 +417,9 @@ def test_model_rope_scaling_from_config(self, scaling_type): # The output should be different for long inputs self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5)) - # Copied from tests.models.falcon.test_modeling_falcon.FalconModelTest.test_model_rope_scaling with Falcon->Persimmon + # Copied from tests.models.gpt_neox.test_modeling_gpt_neox.GPTNeoXModelTest.test_model_rope_scaling with GPTNeoX->Persimmon def test_model_rope_scaling(self): config, _ = self.model_tester.prepare_config_and_inputs_for_common() - hidden_size = config.hidden_size - num_heads = config.num_attention_heads - head_dim = hidden_size // num_heads scaling_factor = 10 short_input_length = 10 long_input_length = int(config.max_position_embeddings * 1.5) @@ -435,11 +432,7 @@ def test_model_rope_scaling(self): position_ids_long = position_ids_long.unsqueeze(0) # Sanity check original RoPE - original_rope = PersimmonRotaryEmbedding( - head_dim, - max_position_embeddings=config.max_position_embeddings, - base=config.rope_theta, - ).to(torch_device) + original_rope = PersimmonRotaryEmbedding(config).to(torch_device) original_cos_short, original_sin_short = original_rope(x, position_ids_short) original_cos_long, original_sin_long = original_rope(x, position_ids_long) torch.testing.assert_close(original_cos_short, original_cos_long[:, :short_input_length, :]) @@ -447,13 +440,8 @@ def test_model_rope_scaling(self): # Sanity check linear RoPE scaling # New position "x" should match original position with index "x/scaling_factor" - linear_scaling_rope = PersimmonRotaryEmbedding( - head_dim, - max_position_embeddings=config.max_position_embeddings, - base=config.rope_theta, - scaling_factor=scaling_factor, - rope_type="linear", - ).to(torch_device) + config.rope_scaling = {"type": 
"linear", "factor": scaling_factor} + linear_scaling_rope = PersimmonRotaryEmbedding(config).to(torch_device) linear_cos_short, linear_sin_short = linear_scaling_rope(x, position_ids_short) linear_cos_long, linear_sin_long = linear_scaling_rope(x, position_ids_long) torch.testing.assert_close(linear_cos_short, linear_cos_long[:, :short_input_length, :]) @@ -466,13 +454,8 @@ def test_model_rope_scaling(self): # Sanity check Dynamic NTK RoPE scaling # Scaling should only be observed after a long input is fed. We can observe that the frequencies increase # with scaling_factor (or that `inv_freq` decreases) - ntk_scaling_rope = PersimmonRotaryEmbedding( - head_dim, - max_position_embeddings=config.max_position_embeddings, - base=config.rope_theta, - scaling_factor=scaling_factor, - rope_type="dynamic", - ).to(torch_device) + config.rope_scaling = {"type": "dynamic", "factor": scaling_factor} + ntk_scaling_rope = PersimmonRotaryEmbedding(config).to(torch_device) ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, position_ids_short) ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, position_ids_long) torch.testing.assert_close(ntk_cos_short, original_cos_short) diff --git a/tests/models/phi/test_modeling_phi.py b/tests/models/phi/test_modeling_phi.py index df5278cb34e315..c7b59d278e4fe6 100644 --- a/tests/models/phi/test_modeling_phi.py +++ b/tests/models/phi/test_modeling_phi.py @@ -396,12 +396,9 @@ def test_model_rope_scaling_from_config(self, scaling_type): # The output should be different for long inputs self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5)) - # Copied from tests.models.falcon.test_modeling_falcon.FalconModelTest.test_model_rope_scaling with Falcon->Phi + # Copied from tests.models.gpt_neox.test_modeling_gpt_neox.GPTNeoXModelTest.test_model_rope_scaling with GPTNeoX->Phi def test_model_rope_scaling(self): config, _ = self.model_tester.prepare_config_and_inputs_for_common() - hidden_size = config.hidden_size - num_heads = config.num_attention_heads - head_dim = hidden_size // num_heads scaling_factor = 10 short_input_length = 10 long_input_length = int(config.max_position_embeddings * 1.5) @@ -414,11 +411,7 @@ def test_model_rope_scaling(self): position_ids_long = position_ids_long.unsqueeze(0) # Sanity check original RoPE - original_rope = PhiRotaryEmbedding( - head_dim, - max_position_embeddings=config.max_position_embeddings, - base=config.rope_theta, - ).to(torch_device) + original_rope = PhiRotaryEmbedding(config).to(torch_device) original_cos_short, original_sin_short = original_rope(x, position_ids_short) original_cos_long, original_sin_long = original_rope(x, position_ids_long) torch.testing.assert_close(original_cos_short, original_cos_long[:, :short_input_length, :]) @@ -426,13 +419,8 @@ def test_model_rope_scaling(self): # Sanity check linear RoPE scaling # New position "x" should match original position with index "x/scaling_factor" - linear_scaling_rope = PhiRotaryEmbedding( - head_dim, - max_position_embeddings=config.max_position_embeddings, - base=config.rope_theta, - scaling_factor=scaling_factor, - rope_type="linear", - ).to(torch_device) + config.rope_scaling = {"type": "linear", "factor": scaling_factor} + linear_scaling_rope = PhiRotaryEmbedding(config).to(torch_device) linear_cos_short, linear_sin_short = linear_scaling_rope(x, position_ids_short) linear_cos_long, linear_sin_long = linear_scaling_rope(x, position_ids_long) torch.testing.assert_close(linear_cos_short, linear_cos_long[:, :short_input_length, :]) @@ -445,13 
+433,8 @@ def test_model_rope_scaling(self): # Sanity check Dynamic NTK RoPE scaling # Scaling should only be observed after a long input is fed. We can observe that the frequencies increase # with scaling_factor (or that `inv_freq` decreases) - ntk_scaling_rope = PhiRotaryEmbedding( - head_dim, - max_position_embeddings=config.max_position_embeddings, - base=config.rope_theta, - scaling_factor=scaling_factor, - rope_type="dynamic", - ).to(torch_device) + config.rope_scaling = {"type": "dynamic", "factor": scaling_factor} + ntk_scaling_rope = PhiRotaryEmbedding(config).to(torch_device) ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, position_ids_short) ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, position_ids_long) torch.testing.assert_close(ntk_cos_short, original_cos_short) diff --git a/tests/models/qwen2/test_modeling_qwen2.py b/tests/models/qwen2/test_modeling_qwen2.py index 6c32a66e03626c..ecfa9189d12e62 100644 --- a/tests/models/qwen2/test_modeling_qwen2.py +++ b/tests/models/qwen2/test_modeling_qwen2.py @@ -327,7 +327,7 @@ class Qwen2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi ) test_headmasking = False test_pruning = False - fx_compatible = True + fx_compatible = False # Broken by attention refactor cc @Cyrilvallez # TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146 def is_pipeline_test_to_skip( diff --git a/tests/models/qwen2_audio/test_modeling_qwen2_audio.py b/tests/models/qwen2_audio/test_modeling_qwen2_audio.py index 42b521e518e22e..4806ec2c72d339 100644 --- a/tests/models/qwen2_audio/test_modeling_qwen2_audio.py +++ b/tests/models/qwen2_audio/test_modeling_qwen2_audio.py @@ -206,15 +206,6 @@ def test_sdpa_can_dispatch_composite_models(self): if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: raise ValueError("The eager model should not have SDPA attention layers") - has_sdpa = False - for name, submodule in model_sdpa.named_modules(): - class_name = submodule.__class__.__name__ - if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: - has_sdpa = True - break - if not has_sdpa and model_sdpa.config.model_type != "falcon": - raise ValueError("The SDPA model should have SDPA attention layers") - @require_torch class Qwen2AudioForConditionalGenerationIntegrationTest(unittest.TestCase): diff --git a/tests/models/qwen2_moe/test_modeling_qwen2_moe.py b/tests/models/qwen2_moe/test_modeling_qwen2_moe.py index abc7b57919b083..21d11047ff1be8 100644 --- a/tests/models/qwen2_moe/test_modeling_qwen2_moe.py +++ b/tests/models/qwen2_moe/test_modeling_qwen2_moe.py @@ -352,7 +352,7 @@ class Qwen2MoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM ) test_headmasking = False test_pruning = False - fx_compatible = True + fx_compatible = False # Broken by attention refactor cc @Cyrilvallez # TODO (ydshieh): Check this. 
See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146 def is_pipeline_test_to_skip( diff --git a/tests/models/speech_encoder_decoder/test_modeling_speech_encoder_decoder.py b/tests/models/speech_encoder_decoder/test_modeling_speech_encoder_decoder.py index 7dcb7c406ae287..897d4b056f1977 100644 --- a/tests/models/speech_encoder_decoder/test_modeling_speech_encoder_decoder.py +++ b/tests/models/speech_encoder_decoder/test_modeling_speech_encoder_decoder.py @@ -500,15 +500,6 @@ def test_sdpa_can_dispatch_composite_models(self): if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: raise ValueError("The eager model should not have SDPA attention layers") - has_sdpa = False - for name, submodule in model_sdpa.named_modules(): - class_name = submodule.__class__.__name__ - if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: - has_sdpa = True - break - if not has_sdpa: - raise ValueError("The SDPA model should have SDPA attention layers") - @require_torch class Wav2Vec2BertModelTest(EncoderDecoderMixin, unittest.TestCase): diff --git a/tests/models/stablelm/test_modeling_stablelm.py b/tests/models/stablelm/test_modeling_stablelm.py index bfab01578229ec..c8aa55399035d2 100644 --- a/tests/models/stablelm/test_modeling_stablelm.py +++ b/tests/models/stablelm/test_modeling_stablelm.py @@ -402,12 +402,9 @@ def test_model_rope_scaling_from_config(self, scaling_type): # The output should be different for long inputs self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5)) - # Copied from tests.models.falcon.test_modeling_falcon.FalconModelTest.test_model_rope_scaling with Falcon->StableLm + # Copied from tests.models.gpt_neox.test_modeling_gpt_neox.GPTNeoXModelTest.test_model_rope_scaling with GPTNeoX->StableLm def test_model_rope_scaling(self): config, _ = self.model_tester.prepare_config_and_inputs_for_common() - hidden_size = config.hidden_size - num_heads = config.num_attention_heads - head_dim = hidden_size // num_heads scaling_factor = 10 short_input_length = 10 long_input_length = int(config.max_position_embeddings * 1.5) @@ -420,11 +417,7 @@ def test_model_rope_scaling(self): position_ids_long = position_ids_long.unsqueeze(0) # Sanity check original RoPE - original_rope = StableLmRotaryEmbedding( - head_dim, - max_position_embeddings=config.max_position_embeddings, - base=config.rope_theta, - ).to(torch_device) + original_rope = StableLmRotaryEmbedding(config).to(torch_device) original_cos_short, original_sin_short = original_rope(x, position_ids_short) original_cos_long, original_sin_long = original_rope(x, position_ids_long) torch.testing.assert_close(original_cos_short, original_cos_long[:, :short_input_length, :]) @@ -432,13 +425,8 @@ def test_model_rope_scaling(self): # Sanity check linear RoPE scaling # New position "x" should match original position with index "x/scaling_factor" - linear_scaling_rope = StableLmRotaryEmbedding( - head_dim, - max_position_embeddings=config.max_position_embeddings, - base=config.rope_theta, - scaling_factor=scaling_factor, - rope_type="linear", - ).to(torch_device) + config.rope_scaling = {"type": "linear", "factor": scaling_factor} + linear_scaling_rope = StableLmRotaryEmbedding(config).to(torch_device) linear_cos_short, linear_sin_short = linear_scaling_rope(x, position_ids_short) linear_cos_long, linear_sin_long = linear_scaling_rope(x, position_ids_long) torch.testing.assert_close(linear_cos_short, 
linear_cos_long[:, :short_input_length, :]) @@ -451,13 +439,8 @@ def test_model_rope_scaling(self): # Sanity check Dynamic NTK RoPE scaling # Scaling should only be observed after a long input is fed. We can observe that the frequencies increase # with scaling_factor (or that `inv_freq` decreases) - ntk_scaling_rope = StableLmRotaryEmbedding( - head_dim, - max_position_embeddings=config.max_position_embeddings, - base=config.rope_theta, - scaling_factor=scaling_factor, - rope_type="dynamic", - ).to(torch_device) + config.rope_scaling = {"type": "dynamic", "factor": scaling_factor} + ntk_scaling_rope = StableLmRotaryEmbedding(config).to(torch_device) ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, position_ids_short) ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, position_ids_long) torch.testing.assert_close(ntk_cos_short, original_cos_short) diff --git a/tests/models/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py b/tests/models/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py index 77e2a19fea4861..2b517034bffb15 100644 --- a/tests/models/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py +++ b/tests/models/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py @@ -441,15 +441,6 @@ def test_sdpa_can_dispatch_composite_models(self): if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: raise ValueError("The eager model should not have SDPA attention layers") - has_sdpa = False - for name, submodule in model_sdpa.named_modules(): - class_name = submodule.__class__.__name__ - if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: - has_sdpa = True - break - if not has_sdpa: - raise ValueError("The SDPA model should have SDPA attention layers") - @require_torch class DeiT2RobertaModelTest(EncoderDecoderMixin, unittest.TestCase): diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 3aaf18c945451f..1d7e995f80756c 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -119,6 +119,7 @@ from torch import nn from transformers import MODEL_MAPPING, AdaptiveEmbedding + from transformers.cache_utils import DynamicCache from transformers.modeling_utils import load_state_dict, no_init_weights from transformers.pytorch_utils import id_tensor_storage @@ -1285,6 +1286,11 @@ def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=Fa ) for i in range(model.config.num_hidden_layers) ) + empty_pkv = ( + DynamicCache.from_legacy_cache(empty_pkv) + if model_class._supports_cache_class + else empty_pkv + ) cache_length = 9 cache_shape = (batch_size, num_heads, cache_length, head_dim) @@ -1295,6 +1301,11 @@ def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=Fa ) for i in range(model.config.num_hidden_layers) ) + non_empty_pkv = ( + DynamicCache.from_legacy_cache(non_empty_pkv) + if model_class._supports_cache_class + else non_empty_pkv + ) inps = copy.deepcopy(inputs_to_test[0]) @@ -2471,7 +2482,7 @@ def _postprocessing_to_ignore_test_cases(self, tf_outputs, pt_outputs, model_cla return new_tf_outputs, new_pt_outputs # Copied from tests.test_modeling_tf_common.TFModelTesterMixin.check_pt_tf_outputs - def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=1e-5, name="outputs", attributes=None): + def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=1e-4, name="outputs", attributes=None): """Check the outputs from PyTorch and TensorFlow models are close enough. Checks are done in a recursive way. 
Args: @@ -2527,6 +2538,8 @@ def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=1e-5, nam attributes = tuple([f"{name}_{idx}" for idx in range(len(tf_outputs))]) for tf_output, pt_output, attr in zip(tf_outputs, pt_outputs, attributes): + if isinstance(pt_output, DynamicCache): + pt_output = pt_output.to_legacy_cache() self.check_pt_tf_outputs(tf_output, pt_output, model_class, tol=tol, name=attr) elif isinstance(tf_outputs, tf.Tensor): @@ -2702,7 +2715,7 @@ def assert_almost_equals(self, a: np.ndarray, b: np.ndarray, tol: float): diff = np.abs((a - b)).max() self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).") - def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=1e-5, name="outputs", attributes=None): + def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=1e-4, name="outputs", attributes=None): """ Args: model_class: The class of the model that is currently testing. For example, ..., etc. @@ -2712,7 +2725,6 @@ def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=1e-5, n Currently unused, but in the future, we could use this information to make the error message clearer by giving the name(s) of the output tensor(s) with large difference(s) between PT and Flax. """ - self.assertEqual(type(name), str) if attributes is not None: self.assertEqual(type(attributes), tuple, f"{name}: The argument `attributes` should be a `tuple`") @@ -2757,6 +2769,8 @@ def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=1e-5, n attributes = tuple([f"{name}_{idx}" for idx in range(len(fx_outputs))]) for fx_output, pt_output, attr in zip(fx_outputs, pt_outputs, attributes): + if isinstance(pt_output, DynamicCache): + pt_output = pt_output.to_legacy_cache() self.check_pt_flax_outputs(fx_output, pt_output, model_class, tol=tol, name=attr) elif isinstance(fx_outputs, jnp.ndarray): @@ -3881,15 +3895,6 @@ def test_sdpa_can_dispatch_non_composite_models(self): if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: raise ValueError("The eager model should not have SDPA attention layers") - has_sdpa = False - for name, submodule in model_sdpa.named_modules(): - class_name = submodule.__class__.__name__ - if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: - has_sdpa = True - break - if not has_sdpa and model_sdpa.config.model_type != "falcon": - raise ValueError("The SDPA model should have SDPA attention layers") - @require_torch_sdpa def test_sdpa_can_dispatch_composite_models(self): """ @@ -3942,15 +3947,6 @@ def test_sdpa_can_dispatch_composite_models(self): if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: raise ValueError("The eager model should not have SDPA attention layers") - has_sdpa = False - for name, submodule in model_sdpa.named_modules(): - class_name = submodule.__class__.__name__ - if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: - has_sdpa = True - break - if not has_sdpa and any(module_attn == "sdpa" for module_attn in [text_attn, vision_attn]): - raise ValueError("The SDPA model should have SDPA attention layers") - @parameterized.expand([("float16",), ("bfloat16",), ("float32",)]) @require_torch_sdpa def test_eager_matches_sdpa_inference(self, torch_dtype: str): diff --git a/tests/test_modeling_flax_common.py b/tests/test_modeling_flax_common.py index c7d098be3ea8f2..bfe1648de049e1 100644 --- a/tests/test_modeling_flax_common.py +++ b/tests/test_modeling_flax_common.py @@ -23,6 +23,7 @@ 
import transformers from transformers import is_flax_available, is_torch_available +from transformers.cache_utils import DynamicCache from transformers.models.auto import get_values from transformers.testing_utils import CaptureLogger, is_pt_flax_cross_test, require_flax, torch_device from transformers.utils import CONFIG_NAME, GENERATION_CONFIG_NAME, logging @@ -180,7 +181,7 @@ def recursive_check(tuple_object, dict_object): check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True}) # (Copied from tests.test_modeling_common.ModelTesterMixin.check_pt_flax_outputs) - def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=1e-5, name="outputs", attributes=None): + def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=1e-4, name="outputs", attributes=None): """ Args: model_class: The class of the model that is currently testing. For example, ..., etc. @@ -190,7 +191,6 @@ def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=1e-5, n Currently unused, but in the future, we could use this information to make the error message clearer by giving the name(s) of the output tensor(s) with large difference(s) between PT and Flax. """ - self.assertEqual(type(name), str) if attributes is not None: self.assertEqual(type(attributes), tuple, f"{name}: The argument `attributes` should be a `tuple`") @@ -235,6 +235,8 @@ def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=1e-5, n attributes = tuple([f"{name}_{idx}" for idx in range(len(fx_outputs))]) for fx_output, pt_output, attr in zip(fx_outputs, pt_outputs, attributes): + if isinstance(pt_output, DynamicCache): + pt_output = pt_output.to_legacy_cache() self.check_pt_flax_outputs(fx_output, pt_output, model_class, tol=tol, name=attr) elif isinstance(fx_outputs, jnp.ndarray): diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index eb328d83e9e7a4..9dc712ab67b682 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -484,7 +484,7 @@ def _postprocessing_to_ignore_test_cases(self, tf_outputs, pt_outputs, model_cla return new_tf_outputs, new_pt_outputs - def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=1e-5, name="outputs", attributes=None): + def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=1e-4, name="outputs", attributes=None): """Check the outputs from PyTorch and TensorFlow models are close enough. Checks are done in a recursive way. Args: @@ -495,6 +495,7 @@ def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=1e-5, nam attributes (`Tuple[str]`): The names of the output's element if the output is a tuple/list with each element being a named field in the output. 
""" + from transformers.cache_utils import DynamicCache self.assertEqual(type(name), str) if attributes is not None: @@ -540,6 +541,8 @@ def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=1e-5, nam attributes = tuple([f"{name}_{idx}" for idx in range(len(tf_outputs))]) for tf_output, pt_output, attr in zip(tf_outputs, pt_outputs, attributes): + if isinstance(pt_output, DynamicCache): + pt_output = pt_output.to_legacy_cache() self.check_pt_tf_outputs(tf_output, pt_output, model_class, tol=tol, name=attr) elif isinstance(tf_outputs, tf.Tensor): diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py index 31c0d01af776ac..383f0cbe60e1c9 100644 --- a/tests/utils/test_modeling_utils.py +++ b/tests/utils/test_modeling_utils.py @@ -563,32 +563,17 @@ def test_model_from_pretrained_attn_implementation(self): if is_flash_attn_2_available(): attn_implementation_available.append("flash_attention_2") - mistral_attention_classes = { - "eager": "MistralAttention", - "sdpa": "MistralSdpaAttention", - "flash_attention_2": "MistralFlashAttention2", - } for requested_attn_implementation in attn_implementation_available: model = AutoModelForCausalLM.from_pretrained( TINY_MISTRAL, attn_implementation=requested_attn_implementation ) self.assertEqual(model.config._attn_implementation, requested_attn_implementation) - for module in model.modules(): - if "Attention" in module.__class__.__name__: - self.assertEqual( - module.__class__.__name__, mistral_attention_classes[requested_attn_implementation] - ) config = AutoConfig.from_pretrained(TINY_MISTRAL) model = AutoModelForCausalLM.from_pretrained( TINY_MISTRAL, config=config, attn_implementation=requested_attn_implementation ) self.assertEqual(model.config._attn_implementation, requested_attn_implementation) - for module in model.modules(): - if "Attention" in module.__class__.__name__: - self.assertEqual( - module.__class__.__name__, mistral_attention_classes[requested_attn_implementation] - ) def test_model_from_config_attn_implementation(self): # test that the model can be instantiated with attn_implementation of either @@ -602,11 +587,6 @@ def test_model_from_config_attn_implementation(self): if is_flash_attn_2_available(): attn_implementation_available.append("flash_attention_2") - mistral_attention_classes = { - "eager": "MistralAttention", - "sdpa": "MistralSdpaAttention", - "flash_attention_2": "MistralFlashAttention2", - } for requested_attn_implementation in attn_implementation_available: config = AutoConfig.from_pretrained(TINY_MISTRAL, attn_implementation=requested_attn_implementation) # Ensure the config was set correctly @@ -614,11 +594,6 @@ def test_model_from_config_attn_implementation(self): self.assertEqual(config._attn_implementation_internal, requested_attn_implementation) model = AutoModelForCausalLM.from_config(config) self.assertEqual(model.config._attn_implementation, requested_attn_implementation) - for module in model.modules(): - if "Attention" in module.__class__.__name__: - self.assertEqual( - module.__class__.__name__, mistral_attention_classes[requested_attn_implementation] - ) config = AutoConfig.from_pretrained(TINY_MISTRAL) # When the config is not set, the default is "eager" @@ -626,11 +601,6 @@ def test_model_from_config_attn_implementation(self): self.assertEqual(config._attn_implementation_internal, None) model = AutoModelForCausalLM.from_config(config=config, attn_implementation=requested_attn_implementation) self.assertEqual(model.config._attn_implementation, 
requested_attn_implementation) - for module in model.modules(): - if "Attention" in module.__class__.__name__: - self.assertEqual( - module.__class__.__name__, mistral_attention_classes[requested_attn_implementation] - ) # Set a nonsense attn_implementation in the config, which should be overridden by the explicit argument config = AutoConfig.from_pretrained(TINY_MISTRAL, attn_implementation="foo-bar-baz") @@ -638,11 +608,6 @@ def test_model_from_config_attn_implementation(self): self.assertEqual(config._attn_implementation_internal, "foo-bar-baz") model = AutoModelForCausalLM.from_config(config=config, attn_implementation=requested_attn_implementation) self.assertEqual(model.config._attn_implementation, requested_attn_implementation) - for module in model.modules(): - if "Attention" in module.__class__.__name__: - self.assertEqual( - module.__class__.__name__, mistral_attention_classes[requested_attn_implementation] - ) def test_torch_dtype_byte_sizes(self): torch_dtypes_and_bytes = [ diff --git a/utils/check_config_attributes.py b/utils/check_config_attributes.py index a125387ff29268..420d6e6a2475d1 100644 --- a/utils/check_config_attributes.py +++ b/utils/check_config_attributes.py @@ -307,6 +307,10 @@ def check_attribute_being_used(config_class, attributes, default_value, source_s "backbone_config", "use_timm_backbone", "backbone_kwargs", + # rope attributes may not appear directly in the modeling but are used + "rope_theta", + "partial_rotary_factor", + "pretraining_tp", ] attributes_used_in_generation = ["encoder_no_repeat_ngram_size"] From 9a94dfe1239ddfb8010a654aa1e677d56c01eee0 Mon Sep 17 00:00:00 2001 From: Luc Georges Date: Wed, 18 Dec 2024 18:59:07 +0100 Subject: [PATCH 03/23] feat: add `benchmarks_entrypoint.py` (#34495) * feat: add `benchmarks_entrypoint.py` Adding `benchmarks_entrypoint.py` file, which will be run from the benchmarks CI. This python script will list all python files from the `benchmark/` folder and run the included `run_benchmark` function, allowing people to add new benchmarks scripts. 
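For illustration, a minimal sketch of the plugin contract this entrypoint expects: the `run_benchmark` signature and the `MetricsRecorder` calls are taken from the README and `benchmarks_entrypoint.py` added below in this patch, while the file name `my_benchmark.py` and the metadata values are placeholders and not part of the patch itself.

```py
# benchmark/my_benchmark.py -- hypothetical example; any *.py file in benchmark/ is discovered and run
from logging import Logger

import psycopg2

from benchmarks_entrypoint import MetricsRecorder


def run_benchmark(logger: Logger, branch: str, commit_id: str, commit_msg: str, num_tokens_to_generate=100):
    # benchmarks_entrypoint.py imports this module and calls run_benchmark(logger, branch, commit_id, commit_msg)
    metrics_recorder = MetricsRecorder(psycopg2.connect("dbname=metrics"), logger, branch, commit_id, commit_msg)
    try:
        # placeholder metadata; a real benchmark records the actual GPU and model under test
        benchmark_id = metrics_recorder.initialise_benchmark(
            {"gpu_name": "placeholder-gpu", "model_id": "placeholder-model"}
        )
        # ... load and time the model here, then store the numbers ...
        metrics_recorder.collect_model_measurements(benchmark_id, {"model_load_time": 0.0})
    finally:
        metrics_recorder.close()
```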
* feat: add `MetricsRecorder` * feat: update dashboard * fix: add missing arguments to `MetricsRecorder` * feat: update dash & add datasource + `default.yml` * fix: move responsibility to create `MetricsRecorder` in bench script * fix: update incorrect datasource UID * fix: incorrect variable values * debug: benchmark entrypoint script * refactor: update log level * fix: update broken import * feat: add debug log in `MetricsRecorder` * debug: set log level to debug * fix: set connection `autocommit` to `True` --- .github/workflows/benchmark.yml | 2 +- benchmark/README.md | 49 ++++++++++ benchmark/benchmarks_entrypoint.py | 144 ++++++++++++++++++++++++++++ benchmark/default.yml | 10 ++ benchmark/grafana_dashboard.json | 145 ++++++++++++++++------------- benchmark/grafana_datasource.yaml | 17 ++++ benchmark/init_db.sql | 2 +- benchmark/llama.py | 134 +++++++------------------- 8 files changed, 334 insertions(+), 169 deletions(-) create mode 100644 benchmark/README.md create mode 100644 benchmark/benchmarks_entrypoint.py create mode 100644 benchmark/default.yml create mode 100644 benchmark/grafana_datasource.yaml diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index eaa4b3b2f82456..1bbd1c1e94d08c 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -63,7 +63,7 @@ jobs: commit_id=$GITHUB_SHA fi commit_msg=$(git show -s --format=%s | cut -c1-70) - python3 benchmark/llama.py "${{ github.head_ref || github.ref_name }}" "$commit_id" "$commit_msg" + python3 benchmark/benchmarks_entrypoint.py "${{ github.head_ref || github.ref_name }}" "$commit_id" "$commit_msg" env: HF_TOKEN: ${{ secrets.HF_HUB_READ_TOKEN }} # Enable this to see debug logs diff --git a/benchmark/README.md b/benchmark/README.md new file mode 100644 index 00000000000000..a827da444f0801 --- /dev/null +++ b/benchmark/README.md @@ -0,0 +1,49 @@ +# Benchmarks + +You might want to add new benchmarks. + +You will need to define a python function named `run_benchmark` in your python file and the file must be located in this `benchmark/` directory. + +The expected function signature is the following: + +```py +def run_benchmark(logger: Logger, branch: str, commit_id: str, commit_msg: str, num_tokens_to_generate=100): +``` + +## Writing metrics to the database + +`MetricRecorder` is thread-safe, in the sense of the python [`Thread`](https://docs.python.org/3/library/threading.html#threading.Thread). This means you can start a background thread to do the readings on the device measurements while not blocking the main thread to execute the model measurements. + +cf [`llama.py`](./llama.py) to see an example of this in practice. 
+ +```py +from benchmarks_entrypoint import MetricsRecorder +import psycopg2 + +def run_benchmark(logger: Logger, branch: str, commit_id: str, commit_msg: str, num_tokens_to_generate=100): + metrics_recorder = MetricsRecorder(psycopg2.connect("dbname=metrics"), logger, branch, commit_id, commit_msg) + benchmark_id = metrics_recorder.initialise_benchmark({"gpu_name": gpu_name, "model_id": model_id}) + # To collect device measurements + metrics_recorder.collect_device_measurements( + benchmark_id, cpu_util, mem_megabytes, gpu_util, gpu_mem_megabytes + ) + # To collect your model measurements + metrics_recorder.collect_model_measurements( + benchmark_id, + { + "model_load_time": model_load_time, + "first_eager_forward_pass_time_secs": first_eager_fwd_pass_time, + "second_eager_forward_pass_time_secs": second_eager_fwd_pass_time, + "first_eager_generate_time_secs": first_eager_generate_time, + "second_eager_generate_time_secs": second_eager_generate_time, + "time_to_first_token_secs": time_to_first_token, + "time_to_second_token_secs": time_to_second_token, + "time_to_third_token_secs": time_to_third_token, + "time_to_next_token_mean_secs": mean_time_to_next_token, + "first_compile_generate_time_secs": first_compile_generate_time, + "second_compile_generate_time_secs": second_compile_generate_time, + "third_compile_generate_time_secs": third_compile_generate_time, + "fourth_compile_generate_time_secs": fourth_compile_generate_time, + }, + ) +``` diff --git a/benchmark/benchmarks_entrypoint.py b/benchmark/benchmarks_entrypoint.py new file mode 100644 index 00000000000000..7925e2902834f7 --- /dev/null +++ b/benchmark/benchmarks_entrypoint.py @@ -0,0 +1,144 @@ +import argparse +import importlib.util +import logging +import os +from typing import Dict +import psycopg2 +import sys + +from psycopg2.extras import Json +from psycopg2.extensions import register_adapter + + +register_adapter(dict, Json) + + +class ImportModuleException(Exception): + pass + + +class MetricsRecorder: + def __init__(self, connection, logger: logging.Logger, branch: str, commit_id: str, commit_msg: str): + self.conn = connection + self.conn.autocommit = True + self.logger = logger + self.branch = branch + self.commit_id = commit_id + self.commit_msg = commit_msg + + def initialise_benchmark(self, metadata: Dict[str, str]) -> int: + """ + Creates a new benchmark, returns the benchmark id + """ + # gpu_name: str, model_id: str + with self.conn.cursor() as cur: + cur.execute( + "INSERT INTO benchmarks (branch, commit_id, commit_message, metadata) VALUES (%s, %s, %s, %s) RETURNING benchmark_id", + (self.branch, self.commit_id, self.commit_msg, metadata), + ) + benchmark_id = cur.fetchone()[0] + logger.debug(f"initialised benchmark #{benchmark_id}") + return benchmark_id + + def collect_device_measurements(self, benchmark_id: int, cpu_util, mem_megabytes, gpu_util, gpu_mem_megabytes): + """ + Collect device metrics, such as CPU & GPU usage. These are "static", as in you cannot pass arbitrary arguments to the function. 
+ """ + with self.conn.cursor() as cur: + cur.execute( + "INSERT INTO device_measurements (benchmark_id, cpu_util, mem_megabytes, gpu_util, gpu_mem_megabytes) VALUES (%s, %s, %s, %s, %s)", + (benchmark_id, cpu_util, mem_megabytes, gpu_util, gpu_mem_megabytes), + ) + self.logger.debug( + f"inserted device measurements for benchmark #{benchmark_id} [CPU util: {cpu_util}, mem MBs: {mem_megabytes}, GPU util: {gpu_util}, GPU mem MBs: {gpu_mem_megabytes}]" + ) + + def collect_model_measurements(self, benchmark_id: int, measurements: Dict[str, float]): + with self.conn.cursor() as cur: + cur.execute( + """ + INSERT INTO model_measurements ( + benchmark_id, + measurements + ) VALUES (%s, %s) + """, + ( + benchmark_id, + measurements, + ), + ) + self.logger.debug(f"inserted model measurements for benchmark #{benchmark_id}: {measurements}") + + def close(self): + self.conn.close() + + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + +handler = logging.StreamHandler(sys.stdout) +handler.setLevel(logging.INFO) +formatter = logging.Formatter("[%(levelname)s - %(asctime)s] %(message)s") +handler.setFormatter(formatter) +logger.addHandler(handler) + + +def parse_arguments(): + """ + Parse command line arguments for the benchmarking CLI. + """ + parser = argparse.ArgumentParser(description="CLI for benchmarking the huggingface/transformers.") + + parser.add_argument( + "branch", + type=str, + help="The branch name on which the benchmarking is performed.", + ) + + parser.add_argument( + "commit_id", + type=str, + help="The commit hash on which the benchmarking is performed.", + ) + + parser.add_argument( + "commit_msg", + type=str, + help="The commit message associated with the commit, truncated to 70 characters.", + ) + + args = parser.parse_args() + + return args.branch, args.commit_id, args.commit_msg + + +def import_from_path(module_name, file_path): + try: + spec = importlib.util.spec_from_file_location(module_name, file_path) + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) + return module + except Exception as e: + raise ImportModuleException(f"failed to load python module: {e}") + + +if __name__ == "__main__": + benchmarks_folder_path = os.path.dirname(os.path.realpath(__file__)) + + branch, commit_id, commit_msg = parse_arguments() + + for entry in os.scandir(benchmarks_folder_path): + try: + if not entry.name.endswith(".py"): + continue + if entry.path == __file__: + continue + logger.debug(f"loading: {entry.name}") + module = import_from_path(entry.name.split(".")[0], entry.path) + logger.info(f"runnning benchmarks in: {entry.name}") + module.run_benchmark(logger, branch, commit_id, commit_msg) + except ImportModuleException as e: + logger.error(e) + except Exception as e: + logger.error(f"error running benchmarks for {entry.name}: {e}") diff --git a/benchmark/default.yml b/benchmark/default.yml new file mode 100644 index 00000000000000..f3f02cab34d1bd --- /dev/null +++ b/benchmark/default.yml @@ -0,0 +1,10 @@ +apiVersion: 1 + +providers: + - name: 'Transformers Benchmarks' + orgId: 1 + type: file + updateIntervalSeconds: 10 + allowUiUpdates: true + options: + path: /etc/grafana/dashboards diff --git a/benchmark/grafana_dashboard.json b/benchmark/grafana_dashboard.json index 3d579f7b368711..caaec78a522303 100644 --- a/benchmark/grafana_dashboard.json +++ b/benchmark/grafana_dashboard.json @@ -30,7 +30,7 @@ "title": "Go to data", "tooltip": "Go to data", "type": "link", - "url": 
"http://transformers-benchmarks.huggingface.co/d/fdz33iyzln9c0a/transformers-benchmarks?orgId=1&from=${StartTime}&to=${EndTime}" + "url": "http://transformers-benchmarks.hf.co/d/fdz33iyzln9c0a/transformers-benchmarks?orgId=1&from=${StartTime}&to=${EndTime}" } ], "liveNow": true, @@ -77,7 +77,7 @@ "properties": [ { "id": "custom.width", - "value": 196 + "value": 202 } ] }, @@ -101,7 +101,7 @@ "properties": [ { "id": "custom.width", - "value": 581 + "value": 524 } ] }, @@ -113,7 +113,19 @@ "properties": [ { "id": "custom.width", - "value": 379 + "value": 353 + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "model_id" + }, + "properties": [ + { + "id": "custom.width", + "value": 216 } ] } @@ -143,12 +155,14 @@ "targets": [ { "datasource": { - "type": "grafana-postgresql-datasource" + "default": true, + "type": "grafana-postgresql-datasource", + "uid": "be28nkzirtb0gd" }, "editorMode": "code", "format": "table", "rawQuery": true, - "rawSql": "SELECT commit_id as commit_id, commit_message, gpu_name, created_at AS date FROM benchmarks WHERE branch = '${branch}' ORDER BY benchmark_id DESC LIMIT ${last_n_commits};", + "rawSql": "SELECT commit_id, commit_message, metadata->>'gpu_name' as gpu_name, metadata->>'model_id' as model_id, created_at AS date FROM benchmarks WHERE branch = '${branch}' AND metadata->>'gpu_name' = '${gpu_name}' ORDER BY benchmark_id DESC LIMIT ${last_n_commits};", "refId": "A", "sql": { "columns": [ @@ -306,13 +320,14 @@ "targets": [ { "datasource": { + "default": true, "type": "grafana-postgresql-datasource", - "uid": "bdz2yss7sxo1sc" + "uid": "be28nkzirtb0gd" }, "editorMode": "code", "format": "table", "rawQuery": true, - "rawSql": "SELECT CAST(m.measurements->'first_eager_forward_pass_time_secs' AS double precision) AS first_eager_forward_pass_time_secs, left(b.commit_id, 7), m.time FROM benchmarks as b JOIN model_measurements AS m ON b.benchmark_id = m.benchmark_id WHERE b.branch = '${branch}' AND gpu_name = '${gpu_name}' ORDER BY b.benchmark_id DESC LIMIT ${last_n_commits};", + "rawSql": "SELECT CAST(m.measurements->'first_eager_forward_pass_time_secs' AS double precision) AS first_eager_forward_pass_time_secs, left(b.commit_id, 7), m.time FROM benchmarks as b JOIN model_measurements AS m ON b.benchmark_id = m.benchmark_id WHERE b.branch = '${branch}' AND b.metadata->>'gpu_name' = '${gpu_name}' ORDER BY b.benchmark_id DESC LIMIT ${last_n_commits};", "refId": "A", "sql": { "columns": [ @@ -431,13 +446,14 @@ "targets": [ { "datasource": { + "default": true, "type": "grafana-postgresql-datasource", - "uid": "bdz2yss7sxo1sc" + "uid": "be28nkzirtb0gd" }, "editorMode": "code", "format": "table", "rawQuery": true, - "rawSql": "SELECT CAST(m.measurements->'second_eager_forward_pass_time_secs' AS double precision) AS second_eager_forward_pass_time_secs, left(b.commit_id, 7), m.time FROM benchmarks as b JOIN model_measurements AS m ON b.benchmark_id = m.benchmark_id WHERE b.branch = '${branch}' AND gpu_name = '${gpu_name}' ORDER BY b.benchmark_id DESC LIMIT ${last_n_commits};", + "rawSql": "SELECT CAST(m.measurements->'second_eager_forward_pass_time_secs' AS double precision) AS second_eager_forward_pass_time_secs, left(b.commit_id, 7), m.time FROM benchmarks as b JOIN model_measurements AS m ON b.benchmark_id = m.benchmark_id WHERE b.branch = '${branch}' AND b.metadata->>'gpu_name' = '${gpu_name}' ORDER BY b.benchmark_id DESC LIMIT ${last_n_commits};", "refId": "A", "sql": { "columns": [ @@ -565,13 +581,14 @@ "targets": [ { "datasource": { + "default": true, "type": 
"grafana-postgresql-datasource", - "uid": "bdz2yss7sxo1sc" + "uid": "be28nkzirtb0gd" }, "editorMode": "code", "format": "table", "rawQuery": true, - "rawSql": "SELECT CAST(m.measurements->'time_to_first_token_secs' AS double precision) AS time_to_first_token_secs, left(b.commit_id, 7), m.time FROM benchmarks as b JOIN model_measurements AS m ON b.benchmark_id = m.benchmark_id WHERE b.branch = '${branch}' AND gpu_name = '${gpu_name}' ORDER BY b.benchmark_id DESC LIMIT ${last_n_commits};", + "rawSql": "SELECT CAST(m.measurements->'time_to_first_token_secs' AS double precision) AS time_to_first_token_secs, left(b.commit_id, 7), m.time FROM benchmarks as b JOIN model_measurements AS m ON b.benchmark_id = m.benchmark_id WHERE b.branch = '${branch}' AND b.metadata->>'gpu_name' = '${gpu_name}' ORDER BY b.benchmark_id DESC LIMIT ${last_n_commits};", "refId": "A", "sql": { "columns": [ @@ -686,13 +703,14 @@ "targets": [ { "datasource": { + "default": true, "type": "grafana-postgresql-datasource", - "uid": "bdz2yss7sxo1sc" + "uid": "be28nkzirtb0gd" }, "editorMode": "code", "format": "table", "rawQuery": true, - "rawSql": "SELECT CAST(m.measurements->'time_to_second_token_secs' AS double precision) AS time_to_second_token_secs, left(b.commit_id, 7), m.time FROM benchmarks as b JOIN model_measurements AS m ON b.benchmark_id = m.benchmark_id WHERE b.branch = '${branch}' AND gpu_name = '${gpu_name}' ORDER BY b.benchmark_id DESC LIMIT ${last_n_commits};", + "rawSql": "SELECT CAST(m.measurements->'time_to_second_token_secs' AS double precision) AS time_to_second_token_secs, left(b.commit_id, 7), m.time FROM benchmarks as b JOIN model_measurements AS m ON b.benchmark_id = m.benchmark_id WHERE b.branch = '${branch}' AND b.metadata->>'gpu_name' = '${gpu_name}' ORDER BY b.benchmark_id DESC LIMIT ${last_n_commits};", "refId": "A", "sql": { "columns": [ @@ -807,13 +825,14 @@ "targets": [ { "datasource": { + "default": true, "type": "grafana-postgresql-datasource", - "uid": "bdz2yss7sxo1sc" + "uid": "be28nkzirtb0gd" }, "editorMode": "code", "format": "table", "rawQuery": true, - "rawSql": "SELECT CAST(m.measurements->'time_to_third_token_secs' AS double precision) AS time_to_third_token_secs, left(b.commit_id, 7), m.time FROM benchmarks as b JOIN model_measurements AS m ON b.benchmark_id = m.benchmark_id WHERE b.branch = '${branch}' AND gpu_name = '${gpu_name}' ORDER BY b.benchmark_id DESC LIMIT ${last_n_commits};", + "rawSql": "SELECT CAST(m.measurements->'time_to_third_token_secs' AS double precision) AS time_to_third_token_secs, left(b.commit_id, 7), m.time FROM benchmarks as b JOIN model_measurements AS m ON b.benchmark_id = m.benchmark_id WHERE b.branch = '${branch}' AND b.metadata->>'gpu_name' = '${gpu_name}' ORDER BY b.benchmark_id DESC LIMIT ${last_n_commits};", "refId": "A", "sql": { "columns": [ @@ -928,13 +947,14 @@ "targets": [ { "datasource": { + "default": true, "type": "grafana-postgresql-datasource", - "uid": "bdz2yss7sxo1sc" + "uid": "be28nkzirtb0gd" }, "editorMode": "code", "format": "table", "rawQuery": true, - "rawSql": "SELECT CAST(m.measurements->'time_to_next_token_mean_secs' AS double precision) AS time_to_next_token_mean_secs, left(b.commit_id, 7), m.time FROM benchmarks as b JOIN model_measurements AS m ON b.benchmark_id = m.benchmark_id WHERE b.branch = '${branch}' AND gpu_name = '${gpu_name}' ORDER BY b.benchmark_id DESC LIMIT ${last_n_commits};", + "rawSql": "SELECT CAST(m.measurements->'time_to_next_token_mean_secs' AS double precision) AS time_to_next_token_mean_secs, 
left(b.commit_id, 7), m.time FROM benchmarks as b JOIN model_measurements AS m ON b.benchmark_id = m.benchmark_id WHERE b.branch = '${branch}' AND b.metadata->>'gpu_name' = '${gpu_name}' ORDER BY b.benchmark_id DESC LIMIT ${last_n_commits};", "refId": "A", "sql": { "columns": [ @@ -1062,13 +1082,14 @@ "targets": [ { "datasource": { + "default": true, "type": "grafana-postgresql-datasource", - "uid": "bdz2yss7sxo1sc" + "uid": "be28nkzirtb0gd" }, "editorMode": "code", "format": "table", "rawQuery": true, - "rawSql": "SELECT CAST(m.measurements->'first_compile_generate_time_secs' AS double precision) AS first_compile_generate_time_secs, left(b.commit_id, 7), m.time FROM benchmarks as b JOIN model_measurements AS m ON b.benchmark_id = m.benchmark_id WHERE b.branch = '${branch}' AND gpu_name = '${gpu_name}' ORDER BY b.benchmark_id DESC LIMIT ${last_n_commits};", + "rawSql": "SELECT CAST(m.measurements->'first_compile_generate_time_secs' AS double precision) AS first_compile_generate_time_secs, left(b.commit_id, 7), m.time FROM benchmarks as b JOIN model_measurements AS m ON b.benchmark_id = m.benchmark_id WHERE b.branch = '${branch}' AND b.metadata->>'gpu_name' = '${gpu_name}' ORDER BY b.benchmark_id DESC LIMIT ${last_n_commits};", "refId": "A", "sql": { "columns": [ @@ -1183,13 +1204,14 @@ "targets": [ { "datasource": { + "default": true, "type": "grafana-postgresql-datasource", - "uid": "bdz2yss7sxo1sc" + "uid": "be28nkzirtb0gd" }, "editorMode": "code", "format": "table", "rawQuery": true, - "rawSql": "SELECT CAST(m.measurements->'second_compile_generate_time_secs' AS double precision) AS second_compile_generate_time_secs, left(b.commit_id, 7), m.time FROM benchmarks as b JOIN model_measurements AS m ON b.benchmark_id = m.benchmark_id WHERE b.branch = '${branch}' AND gpu_name = '${gpu_name}' ORDER BY b.benchmark_id DESC LIMIT ${last_n_commits};", + "rawSql": "SELECT CAST(m.measurements->'second_compile_generate_time_secs' AS double precision) AS second_compile_generate_time_secs, left(b.commit_id, 7), m.time FROM benchmarks as b JOIN model_measurements AS m ON b.benchmark_id = m.benchmark_id WHERE b.branch = '${branch}' AND b.metadata->>'gpu_name' = '${gpu_name}' ORDER BY b.benchmark_id DESC LIMIT ${last_n_commits};", "refId": "A", "sql": { "columns": [ @@ -1304,13 +1326,14 @@ "targets": [ { "datasource": { + "default": true, "type": "grafana-postgresql-datasource", - "uid": "bdz2yss7sxo1sc" + "uid": "be28nkzirtb0gd" }, "editorMode": "code", "format": "table", "rawQuery": true, - "rawSql": "SELECT CAST(m.measurements->'third_compile_generate_time_secs' AS double precision) AS third_compile_generate_time_secs, left(b.commit_id, 7), m.time FROM benchmarks as b JOIN model_measurements AS m ON b.benchmark_id = m.benchmark_id WHERE b.branch = '${branch}' AND gpu_name = '${gpu_name}' ORDER BY b.benchmark_id DESC LIMIT ${last_n_commits};", + "rawSql": "SELECT CAST(m.measurements->'third_compile_generate_time_secs' AS double precision) AS third_compile_generate_time_secs, left(b.commit_id, 7), m.time FROM benchmarks as b JOIN model_measurements AS m ON b.benchmark_id = m.benchmark_id WHERE b.branch = '${branch}' AND b.metadata->>'gpu_name' = '${gpu_name}' ORDER BY b.benchmark_id DESC LIMIT ${last_n_commits};", "refId": "A", "sql": { "columns": [ @@ -1425,13 +1448,14 @@ "targets": [ { "datasource": { + "default": true, "type": "grafana-postgresql-datasource", - "uid": "bdz2yss7sxo1sc" + "uid": "be28nkzirtb0gd" }, "editorMode": "code", "format": "table", "rawQuery": true, - "rawSql": "SELECT 
CAST(m.measurements->'fourth_compile_generate_time_secs' AS double precision) AS fourth_compile_generate_time_secs, left(b.commit_id, 7), m.time FROM benchmarks as b JOIN model_measurements AS m ON b.benchmark_id = m.benchmark_id WHERE b.branch = '${branch}' AND gpu_name = '${gpu_name}' ORDER BY b.benchmark_id DESC LIMIT ${last_n_commits};", + "rawSql": "SELECT CAST(m.measurements->'fourth_compile_generate_time_secs' AS double precision) AS fourth_compile_generate_time_secs, left(b.commit_id, 7), m.time FROM benchmarks as b JOIN model_measurements AS m ON b.benchmark_id = m.benchmark_id WHERE b.branch = '${branch}' AND b.metadata->>'gpu_name' = '${gpu_name}' ORDER BY b.benchmark_id DESC LIMIT ${last_n_commits};", "refId": "A", "sql": { "columns": [ @@ -1480,11 +1504,7 @@ "id": 15, "panels": [ { - "datasource": { - "default": true, - "type": "grafana-postgresql-datasource", - "uid": "be28nkzirtb0gd" - }, + "datasource": {}, "fieldConfig": { "defaults": { "color": { @@ -1528,8 +1548,7 @@ "mode": "absolute", "steps": [ { - "color": "green", - "value": null + "color": "green" }, { "color": "red", @@ -1563,8 +1582,9 @@ "targets": [ { "datasource": { + "default": true, "type": "grafana-postgresql-datasource", - "uid": "bdz2yss7sxo1sc" + "uid": "be28nkzirtb0gd" }, "editorMode": "code", "format": "table", @@ -1665,11 +1685,7 @@ "type": "timeseries" }, { - "datasource": { - "default": true, - "type": "grafana-postgresql-datasource", - "uid": "be28nkzirtb0gd" - }, + "datasource": {}, "fieldConfig": { "defaults": { "color": { @@ -1713,8 +1729,7 @@ "mode": "absolute", "steps": [ { - "color": "green", - "value": null + "color": "green" }, { "color": "red", @@ -1748,8 +1763,9 @@ "targets": [ { "datasource": { + "default": true, "type": "grafana-postgresql-datasource", - "uid": "bdz2yss7sxo1sc" + "uid": "be28nkzirtb0gd" }, "editorMode": "code", "format": "table", @@ -1850,11 +1866,7 @@ "type": "timeseries" }, { - "datasource": { - "default": true, - "type": "grafana-postgresql-datasource", - "uid": "be28nkzirtb0gd" - }, + "datasource": {}, "fieldConfig": { "defaults": { "color": { @@ -1898,8 +1910,7 @@ "mode": "absolute", "steps": [ { - "color": "green", - "value": null + "color": "green" }, { "color": "red", @@ -1933,8 +1944,9 @@ "targets": [ { "datasource": { + "default": true, "type": "grafana-postgresql-datasource", - "uid": "bdz2yss7sxo1sc" + "uid": "be28nkzirtb0gd" }, "editorMode": "code", "format": "table", @@ -2035,11 +2047,7 @@ "type": "timeseries" }, { - "datasource": { - "default": true, - "type": "grafana-postgresql-datasource", - "uid": "be28nkzirtb0gd" - }, + "datasource": {}, "fieldConfig": { "defaults": { "color": { @@ -2083,8 +2091,7 @@ "mode": "absolute", "steps": [ { - "color": "green", - "value": null + "color": "green" }, { "color": "red", @@ -2118,8 +2125,9 @@ "targets": [ { "datasource": { + "default": true, "type": "grafana-postgresql-datasource", - "uid": "bdz2yss7sxo1sc" + "uid": "be28nkzirtb0gd" }, "editorMode": "code", "format": "table", @@ -2224,7 +2232,6 @@ "type": "row" } ], - "refresh": "", "schemaVersion": 39, "tags": [], "templating": { @@ -2236,6 +2243,7 @@ "value": "main" }, "datasource": { + "default": true, "type": "grafana-postgresql-datasource", "uid": "be28nkzirtb0gd" }, @@ -2248,7 +2256,7 @@ "name": "branch", "options": [], "query": "SELECT DISTINCT branch FROM benchmarks;", - "refresh": 2, + "refresh": 1, "regex": "", "skipUrlSync": false, "sort": 0, @@ -2261,6 +2269,7 @@ "value": "1729701492845" }, "datasource": { + "default": true, "type": 
"grafana-postgresql-datasource", "uid": "be28nkzirtb0gd" }, @@ -2281,10 +2290,11 @@ { "current": { "selected": false, - "text": "1730120430069", - "value": "1730120430069" + "text": "1730393397577", + "value": "1730393397577" }, "datasource": { + "default": true, "type": "grafana-postgresql-datasource", "uid": "be28nkzirtb0gd" }, @@ -2312,15 +2322,16 @@ "type": "grafana-postgresql-datasource", "uid": "be28nkzirtb0gd" }, - "definition": "SELECT DISTINCT gpu_name FROM benchmarks;", + "definition": "SELECT DISTINCT metadata->>'gpu_name' FROM benchmarks;", + "description": "", "hide": 0, "includeAll": false, "label": "GPU", "multi": false, "name": "gpu_name", "options": [], - "query": "SELECT DISTINCT gpu_name FROM benchmarks;", - "refresh": 2, + "query": "SELECT DISTINCT metadata->>'gpu_name' FROM benchmarks;", + "refresh": 1, "regex": "", "skipUrlSync": false, "sort": 0, @@ -2328,7 +2339,7 @@ }, { "current": { - "selected": false, + "selected": true, "text": "10", "value": "10" }, @@ -2359,6 +2370,6 @@ "timezone": "browser", "title": "Transformers benchmarks", "uid": "fdz33iyzln9c0a", - "version": 4, + "version": 10, "weekStart": "" } diff --git a/benchmark/grafana_datasource.yaml b/benchmark/grafana_datasource.yaml new file mode 100644 index 00000000000000..25f36254104ab5 --- /dev/null +++ b/benchmark/grafana_datasource.yaml @@ -0,0 +1,17 @@ +apiVersion: 1 +datasources: + - name: grafana-postgresql-datasource + uid: be28nkzirtb0gd + type: postgres + url: $GRAFANA_POSTGRES_DATASOURCE_URL + user: $GRAFANA_POSTGRES_DATASOURCE_USER + secureJsonData: + password: $GRAFANA_POSTGRES_DATASOURCE_PWD + jsonData: + database: metrics + maxOpenConns: 100 + maxIdleConns: 100 + maxIdleConnsAuto: true + connMaxLifetime: 14400 + postgresVersion: 1000 + timescaledb: false diff --git a/benchmark/init_db.sql b/benchmark/init_db.sql index 573cc11518e857..a7864c4af183b6 100644 --- a/benchmark/init_db.sql +++ b/benchmark/init_db.sql @@ -3,7 +3,7 @@ CREATE TABLE IF NOT EXISTS benchmarks ( branch VARCHAR(255), commit_id VARCHAR(72), commit_message VARCHAR(70), - gpu_name VARCHAR(255), + metadata jsonb, created_at timestamp without time zone NOT NULL DEFAULT (current_timestamp AT TIME ZONE 'UTC') ); diff --git a/benchmark/llama.py b/benchmark/llama.py index 4a2c57422e6ffb..bbe1afefd5ef1b 100644 --- a/benchmark/llama.py +++ b/benchmark/llama.py @@ -1,71 +1,25 @@ -import argparse -import json -import logging +from logging import Logger import os -import sys -from statistics import mean from threading import Event, Thread from time import perf_counter, sleep from typing import Optional +from benchmarks_entrypoint import MetricsRecorder import gpustat import psutil import psycopg2 import torch from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, StaticCache -from psycopg2.extras import Json -from psycopg2.extensions import register_adapter os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" -logger = logging.getLogger(__name__) -logger.setLevel(logging.INFO) - -handler = logging.StreamHandler(sys.stdout) -handler.setLevel(logging.INFO) -formatter = logging.Formatter("[%(levelname)s - %(asctime)s] %(message)s") -handler.setFormatter(formatter) -logger.addHandler(handler) - os.environ["TOKENIZERS_PARALLELISM"] = "1" torch.set_float32_matmul_precision("high") -register_adapter(dict, Json) - - -def parse_arguments(): - """ - Parse command line arguments for the benchmarking CLI. 
- """ - parser = argparse.ArgumentParser(description="CLI for benchmarking the huggingface/transformers.") - - parser.add_argument( - "branch", - type=str, - help="The branch name on which the benchmarking is performed.", - ) - - parser.add_argument( - "commit_id", - type=str, - help="The commit hash on which the benchmarking is performed.", - ) - parser.add_argument( - "commit_msg", - type=str, - help="The commit message associated with the commit, truncated to 70 characters.", - ) - args = parser.parse_args() - - return args.branch, args.commit_id, args.commit_msg - - -def collect_metrics(benchmark_id, continue_metric_collection): +def collect_metrics(benchmark_id, continue_metric_collection, metrics_recorder): p = psutil.Process(os.getpid()) - conn = psycopg2.connect("dbname=metrics") - cur = conn.cursor() while not continue_metric_collection.is_set(): with p.oneshot(): cpu_util = p.cpu_percent() @@ -73,47 +27,41 @@ def collect_metrics(benchmark_id, continue_metric_collection): gpu_stats = gpustat.GPUStatCollection.new_query() gpu_util = gpu_stats[0]["utilization.gpu"] gpu_mem_megabytes = gpu_stats[0]["memory.used"] - cur.execute( - "INSERT INTO device_measurements (benchmark_id, cpu_util, mem_megabytes, gpu_util, gpu_mem_megabytes) VALUES (%s, %s, %s, %s, %s)", - (benchmark_id, cpu_util, mem_megabytes, gpu_util, gpu_mem_megabytes), + metrics_recorder.collect_device_measurements( + benchmark_id, cpu_util, mem_megabytes, gpu_util, gpu_mem_megabytes ) sleep(0.01) - conn.commit() - conn.close() -def run_benchmark(branch: str, commit_id: str, commit_msg: str, num_tokens_to_generate=100): +def run_benchmark(logger: Logger, branch: str, commit_id: str, commit_msg: str, num_tokens_to_generate=100): continue_metric_collection = Event() metrics_thread = None + model_id = "meta-llama/Llama-2-7b-hf" + metrics_recorder = MetricsRecorder(psycopg2.connect("dbname=metrics"), logger, branch, commit_id, commit_msg) try: gpu_stats = gpustat.GPUStatCollection.new_query() gpu_name = gpu_stats[0]["name"] - conn = psycopg2.connect("dbname=metrics") - cur = conn.cursor() - cur.execute( - "INSERT INTO benchmarks (branch, commit_id, commit_message, gpu_name) VALUES (%s, %s, %s, %s) RETURNING benchmark_id", - (branch, commit_id, commit_msg, gpu_name), + benchmark_id = metrics_recorder.initialise_benchmark({"gpu_name": gpu_name, "model_id": model_id}) + logger.info(f"running benchmark #{benchmark_id} on {gpu_name} for {model_id}") + metrics_thread = Thread( + target=collect_metrics, + args=[benchmark_id, continue_metric_collection, metrics_recorder], ) - conn.commit() - benchmark_id = cur.fetchone()[0] - logger.info(f"running benchmark #{benchmark_id} on {gpu_name}") - metrics_thread = Thread(target=collect_metrics, args=[benchmark_id, continue_metric_collection]) metrics_thread.start() logger.info("started background thread to fetch device metrics") os.environ["TOKENIZERS_PARALLELISM"] = "false" # silence warnings when compiling device = "cuda" - ckpt = "meta-llama/Llama-2-7b-hf" logger.info("downloading weights") # This is to avoid counting download in model load time measurement - model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16) + model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16) gen_config = GenerationConfig(do_sample=False, top_p=1, temperature=1) logger.info("loading model") start = perf_counter() model = AutoModelForCausalLM.from_pretrained( - ckpt, torch_dtype=torch.float16, generation_config=gen_config + model_id, torch_dtype=torch.float16, 
generation_config=gen_config ).eval() model.to(device) torch.cuda.synchronize() @@ -121,7 +69,7 @@ def run_benchmark(branch: str, commit_id: str, commit_msg: str, num_tokens_to_ge model_load_time = end - start logger.info(f"loaded model in: {model_load_time}s") - tokenizer = AutoTokenizer.from_pretrained(ckpt) + tokenizer = AutoTokenizer.from_pretrained(model_id) prompt = "Why dogs are so cute?" inputs = tokenizer(prompt, return_tensors="pt").to(device) @@ -368,41 +316,27 @@ def decode_one_token(model, cur_token, cache_position, past_key_values): logger.info(f"completed second compile generation in: {fourth_compile_generate_time}s") logger.info(f"generated: {tokenizer.batch_decode(output.cpu().tolist())}") - cur.execute( - """ - INSERT INTO model_measurements ( - benchmark_id, - measurements - ) VALUES (%s, %s) - """, - ( - benchmark_id, - { - "model_load_time": model_load_time, - "first_eager_forward_pass_time_secs": first_eager_fwd_pass_time, - "second_eager_forward_pass_time_secs": second_eager_fwd_pass_time, - "first_eager_generate_time_secs": first_eager_generate_time, - "second_eager_generate_time_secs": second_eager_generate_time, - "time_to_first_token_secs": time_to_first_token, - "time_to_second_token_secs": time_to_second_token, - "time_to_third_token_secs": time_to_third_token, - "time_to_next_token_mean_secs": mean_time_to_next_token, - "first_compile_generate_time_secs": first_compile_generate_time, - "second_compile_generate_time_secs": second_compile_generate_time, - "third_compile_generate_time_secs": third_compile_generate_time, - "fourth_compile_generate_time_secs": fourth_compile_generate_time, - }, - ), + metrics_recorder.collect_model_measurements( + benchmark_id, + { + "model_load_time": model_load_time, + "first_eager_forward_pass_time_secs": first_eager_fwd_pass_time, + "second_eager_forward_pass_time_secs": second_eager_fwd_pass_time, + "first_eager_generate_time_secs": first_eager_generate_time, + "second_eager_generate_time_secs": second_eager_generate_time, + "time_to_first_token_secs": time_to_first_token, + "time_to_second_token_secs": time_to_second_token, + "time_to_third_token_secs": time_to_third_token, + "time_to_next_token_mean_secs": mean_time_to_next_token, + "first_compile_generate_time_secs": first_compile_generate_time, + "second_compile_generate_time_secs": second_compile_generate_time, + "third_compile_generate_time_secs": third_compile_generate_time, + "fourth_compile_generate_time_secs": fourth_compile_generate_time, + }, ) - conn.commit() - conn.close() except Exception as e: logger.error(f"Caught exception: {e}") continue_metric_collection.set() if metrics_thread is not None: metrics_thread.join() - - -if __name__ == "__main__": - branch, commit_id, commit_msg = parse_arguments() - run_benchmark(branch, commit_id, commit_msg, num_tokens_to_generate=20) + metrics_recorder.close() From 9613933b022ddbf085e2c593ed4ceea4c734179a Mon Sep 17 00:00:00 2001 From: Yu Chin Fabian Lim Date: Thu, 19 Dec 2024 03:18:17 +0800 Subject: [PATCH 04/23] Add the Bamba Model (#34982) * initial commit for PR Co-authored-by: Gabe Goodhart * rename dynamic cache Signed-off-by: Yu Chin Fabian Lim * add more unit tests Signed-off-by: Yu Chin Fabian Lim * add integration test Signed-off-by: Yu Chin Fabian Lim * add integration test Signed-off-by: Yu Chin Fabian Lim * Add modular bamba file * Remove trainer changes from unrelated PR * Modify modular and cofig to get model running * Fix some CI errors and beam search * Fix a plethora of bugs from CI/docs/etc * Add bamba to 
models with special caches * Updat to newer mamba PR for mamba sublayer * fix test_left_padding_compatibility Signed-off-by: Yu Chin Fabian Lim * fix style Signed-off-by: Yu Chin Fabian Lim * fix remaining tests Signed-off-by: Yu Chin Fabian Lim * missed this test Signed-off-by: Yu Chin Fabian Lim * ran make style Signed-off-by: Yu Chin Fabian Lim * move slow tag to integration obj Signed-off-by: Yu Chin Fabian Lim * make style Signed-off-by: Yu Chin Fabian Lim * address comments Signed-off-by: Yu Chin Fabian Lim * fix modular Signed-off-by: Yu Chin Fabian Lim * left out one part of modular Signed-off-by: Yu Chin Fabian Lim * change model Signed-off-by: Yu Chin Fabian Lim * Make Rotary modular as well * Update bamba.md Added overview, update Model inference card and added config * Update bamba.md * Update bamba.md * Update bamba.md Minor fixes * Add docs for config and model back Signed-off-by: Antoni Viros i Martin * Add warning when using fast kernels * replaced generate example Signed-off-by: Yu Chin Fabian Lim * Address comments from PR Signed-off-by: Antoni Viros i Martin * Propagate attention fixes Signed-off-by: Antoni Viros i Martin * Fix attention interfaces to the new API Signed-off-by: Antoni Viros i Martin * Fix API for decoder layer Signed-off-by: Antoni Viros i Martin * Remove extra weights Signed-off-by: Antoni Viros i Martin --------- Signed-off-by: Yu Chin Fabian Lim Signed-off-by: Antoni Viros i Martin Co-authored-by: Gabe Goodhart Co-authored-by: Antoni Viros i Martin Co-authored-by: divya-kumari32 <72085811+divya-kumari32@users.noreply.github.com> Co-authored-by: Antoni Viros --- docs/source/en/_toctree.yml | 2 + docs/source/en/index.md | 1 + docs/source/en/model_doc/bamba.md | 64 + docs/source/en/perf_infer_gpu_one.md | 2 + src/transformers/__init__.py | 10 + src/transformers/generation/utils.py | 1 + src/transformers/models/__init__.py | 1 + .../models/auto/configuration_auto.py | 2 + src/transformers/models/auto/modeling_auto.py | 2 + src/transformers/models/bamba/__init__.py | 28 + .../models/bamba/configuration_bamba.py | 206 +++ .../bamba/convert_mamba_ssm_checkpoint.py | 273 +++ .../models/bamba/modeling_bamba.py | 1615 +++++++++++++++++ .../models/bamba/modular_bamba.py | 1303 +++++++++++++ src/transformers/utils/dummy_pt_objects.py | 21 + tests/generation/test_utils.py | 1 + tests/models/bamba/__init__.py | 0 tests/models/bamba/test_modeling_bamba.py | 603 ++++++ utils/check_config_attributes.py | 3 + 19 files changed, 4138 insertions(+) create mode 100644 docs/source/en/model_doc/bamba.md create mode 100644 src/transformers/models/bamba/__init__.py create mode 100644 src/transformers/models/bamba/configuration_bamba.py create mode 100644 src/transformers/models/bamba/convert_mamba_ssm_checkpoint.py create mode 100644 src/transformers/models/bamba/modeling_bamba.py create mode 100644 src/transformers/models/bamba/modular_bamba.py create mode 100644 tests/models/bamba/__init__.py create mode 100644 tests/models/bamba/test_modeling_bamba.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 435b482df599cf..c30dfd3fbabc97 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -322,6 +322,8 @@ sections: - local: model_doc/albert title: ALBERT + - local: model_doc/bamba + title: Bamba - local: model_doc/bart title: BART - local: model_doc/barthez diff --git a/docs/source/en/index.md b/docs/source/en/index.md index 3bd1c286d43240..0bd81e9d61be29 100644 --- a/docs/source/en/index.md +++ b/docs/source/en/index.md @@ 
-66,6 +66,7 @@ Flax), PyTorch, and/or TensorFlow. | [AriaText](model_doc/aria_text) | ✅ | ❌ | ❌ | | [Audio Spectrogram Transformer](model_doc/audio-spectrogram-transformer) | ✅ | ❌ | ❌ | | [Autoformer](model_doc/autoformer) | ✅ | ❌ | ❌ | +| [Bamba](model_doc/bamba) | ✅ | ❌ | ❌ | | [Bark](model_doc/bark) | ✅ | ❌ | ❌ | | [BART](model_doc/bart) | ✅ | ✅ | ✅ | | [BARThez](model_doc/barthez) | ✅ | ✅ | ✅ | diff --git a/docs/source/en/model_doc/bamba.md b/docs/source/en/model_doc/bamba.md new file mode 100644 index 00000000000000..4ea8475edb885a --- /dev/null +++ b/docs/source/en/model_doc/bamba.md @@ -0,0 +1,64 @@ + + +# Bamba + + +## Overview + +Bamba-9B is a decoder-only language model based on the [Mamba-2](https://github.com/state-spaces/mamba) architecture and is designed to handle a wide range of text generation tasks. It is trained from scratch using a two-stage training approach. In the first stage, the model is trained on 2 trillion tokens from the Dolma v1.7 dataset. In the second stage, it undergoes additional training on 200 billion tokens, leveraging a carefully curated blend of high-quality data to further refine its performance and enhance output quality. + +Checkout all Bamba-9B model checkpoints [here](https://github.com/foundation-model-stack/bamba). + +## BambaConfig + +| Model | Params | # Layers | Hidden Dim. | Attention Heads | GQA | KV Heads | Context Length | Tied Embeddings | +|-------------------|--------------|----------|-------------|-----------------|-----|----------|----------------|------------------| +| Bamba | 9B (9.78B) | 32 | 4096 | 32 | Yes | 8 | 4096 | True | + +[[autodoc]] BambaConfig + + + +## BambaForCausalLM + +```python +from transformers import AutoModelForCausalLM, AutoTokenizer + +model = AutoModelForCausalLM.from_pretrained("ibm-fms/Bamba-9B") +tokenizer = AutoTokenizer.from_pretrained("ibm-fms/Bamba-9B") + +message = ["Mamba is a snake with following properties "] +inputs = tokenizer(message, return_tensors='pt', return_token_type_ids=False) +response = model.generate(**inputs, max_new_tokens=64) +print(tokenizer.batch_decode(response, skip_special_tokens=True)[0]) +``` + +[[autodoc]] BambaForCausalLM + - forward + +This HF implementation is contributed by [ani300](https://github.com/ani300) and [fabianlim](https://github.com/fabianlim). diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md index cbb498070d69e5..4f9cace5b8d30d 100644 --- a/docs/source/en/perf_infer_gpu_one.md +++ b/docs/source/en/perf_infer_gpu_one.md @@ -39,6 +39,7 @@ FlashAttention-2 is experimental and may change considerably in future versions. 
FlashAttention-2 is currently supported for the following architectures: * [Aria](https://huggingface.co/docs/transformers/model_doc/aria#transformers.AriaForConditionalGeneration) * [Bark](https://huggingface.co/docs/transformers/model_doc/bark#transformers.BarkModel) +* [Bamba](https://huggingface.co/docs/transformers/model_doc/bamba#transformers.BambaModel) * [Bart](https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartModel) * [Chameleon](https://huggingface.co/docs/transformers/model_doc/chameleon#transformers.Chameleon) * [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPModel) @@ -220,6 +221,7 @@ For now, Transformers supports SDPA inference and training for the following arc * [Albert](https://huggingface.co/docs/transformers/model_doc/albert#transformers.AlbertModel) * [Aria](https://huggingface.co/docs/transformers/model_doc/aria#transformers.AriaForConditionalGeneration) * [Audio Spectrogram Transformer](https://huggingface.co/docs/transformers/model_doc/audio-spectrogram-transformer#transformers.ASTModel) +* [Bamba](https://huggingface.co/docs/transformers/model_doc/bamba#transformers.BambaModel) * [Bart](https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartModel) * [Beit](https://huggingface.co/docs/transformers/model_doc/beit#transformers.BeitModel) * [Bert](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertModel) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 920dc334dbb2a4..6a180a90bbbaa2 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -193,6 +193,7 @@ "AutoTokenizer", ], "models.autoformer": ["AutoformerConfig"], + "models.bamba": ["BambaConfig"], "models.bark": [ "BarkCoarseConfig", "BarkConfig", @@ -1540,6 +1541,13 @@ "AutoformerPreTrainedModel", ] ) + _import_structure["models.bamba"].extend( + [ + "BambaForCausalLM", + "BambaModel", + "BambaPreTrainedModel", + ] + ) _import_structure["models.bark"].extend( [ "BarkCausalModel", @@ -5104,6 +5112,7 @@ from .models.autoformer import ( AutoformerConfig, ) + from .models.bamba import BambaConfig from .models.bark import ( BarkCoarseConfig, BarkConfig, @@ -6493,6 +6502,7 @@ AutoformerModel, AutoformerPreTrainedModel, ) + from .models.bamba import BambaForCausalLM, BambaModel, BambaPreTrainedModel from .models.bark import ( BarkCausalModel, BarkCoarseModel, diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index fe634141eca09b..05627e23de11ff 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1693,6 +1693,7 @@ def _supports_default_dynamic_cache(self) -> bool: self._supports_cache_class and "jamba" not in self.__class__.__name__.lower() and "zamba" not in self.__class__.__name__.lower() + and "bamba" not in self.__class__.__name__.lower() ) def _prepare_cache_for_generation( diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 5eb74fab5abe71..5b3c648428359d 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -20,6 +20,7 @@ audio_spectrogram_transformer, auto, autoformer, + bamba, bark, bart, barthez, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index d7d8281c2e3f03..8aba0e75b2690b 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -39,6 +39,7 @@ ("aria_text", 
"AriaTextConfig"), ("audio-spectrogram-transformer", "ASTConfig"), ("autoformer", "AutoformerConfig"), + ("bamba", "BambaConfig"), ("bark", "BarkConfig"), ("bart", "BartConfig"), ("beit", "BeitConfig"), @@ -337,6 +338,7 @@ ("aria_text", "AriaText"), ("audio-spectrogram-transformer", "Audio Spectrogram Transformer"), ("autoformer", "Autoformer"), + ("bamba", "Bamba"), ("bark", "Bark"), ("bart", "BART"), ("barthez", "BARThez"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 5d41ad42beea7e..770e4ea0775f76 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -39,6 +39,7 @@ ("aria_text", "AriaTextModel"), ("audio-spectrogram-transformer", "ASTModel"), ("autoformer", "AutoformerModel"), + ("bamba", "BambaModel"), ("bark", "BarkModel"), ("bart", "BartModel"), ("beit", "BeitModel"), @@ -471,6 +472,7 @@ [ # Model for Causal LM mapping ("aria_text", "AriaTextForCausalLM"), + ("bamba", "BambaForCausalLM"), ("bart", "BartForCausalLM"), ("bert", "BertLMHeadModel"), ("bert-generation", "BertGenerationDecoder"), diff --git a/src/transformers/models/bamba/__init__.py b/src/transformers/models/bamba/__init__.py new file mode 100644 index 00000000000000..c3920da849a333 --- /dev/null +++ b/src/transformers/models/bamba/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2024 IBM and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_bamba import * + from .modeling_bamba import * + from .processing_bamba import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/bamba/configuration_bamba.py b/src/transformers/models/bamba/configuration_bamba.py new file mode 100644 index 00000000000000..f84d63ec04a9c7 --- /dev/null +++ b/src/transformers/models/bamba/configuration_bamba.py @@ -0,0 +1,206 @@ +# coding=utf-8 +# Copyright 2024 IBM and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Bamba model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class BambaConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`BambaModel`]. It is used to instantiate a + BambaModel model according to the specified arguments, defining the model architecture. Instantiating a configuration + with defaults taken from [ibm-fms/Bamba-9.8b-2.2T-hf](https://huggingface.co/ibm-fms/Bamba-9.8b-2.2T-hf). + + The BambaModel is a hybrid [mamba2](https://github.com/state-spaces/mamba) architecture with SwiGLU. + The checkpoints are jointly trained by IBM, Princeton, and UIUC. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 128000): + Vocabulary size of the Bamba model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`BambaModel`] + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. Note that this is only relevant if the + model has a output word embedding layer. + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 14336): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 8): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + num_logits_to_keep (`int` or `None`, *optional*, defaults to 1): + Number of prompt logits to calculate during generation. If `None`, all logits will be calculated. If an + integer value, only last `num_logits_to_keep` logits will be calculated. Default is 1 because only the + logits of the last prompt token are needed for generation. For long sequences, the logits for the entire + sequence may use a lot of memory so, setting `num_logits_to_keep=1` will reduce memory footprint + significantly. 
+ pad_token_id (`int`, *optional*, defaults to 0): + The id of the padding token. + bos_token_id (`int`, *optional*, defaults to 1): + The id of the "beginning-of-sequence" token. + eos_token_id (`int`, *optional*, defaults to 2): + The id of the "end-of-sequence" token. + max_position_embeddings (`int`, *optional*, defaults to 262144): + Max cached sequence length for the model + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + attn_layer_indices (`list`, *optional*): + Specifies the layer indices that will have full attention. Must contain values at most num_hidden_layers. + mamba_n_heads (`int`, *optional*, defaults to 128): + The number of mamba heads used in the v2 implementation. + mamba_d_head (`int`, *optional*, defaults to `"auto"`): + Head embeddding dimension size + mamba_n_groups (`int`, *optional*, defaults to 1): + The number of the mamba groups used in the v2 implementation. + mamba_d_state (`int`, *optional*, defaults to 256): + The dimension the mamba state space latents + mamba_d_conv (`int`, *optional*, defaults to 4): + The size of the mamba convolution kernel + mamba_expand (`int`, *optional*, defaults to 2): + Expanding factor (relative to hidden_size) used to determine the mamba intermediate size + mamba_chunk_size (`int`, *optional*, defaults to 256): + The chunks in which to break the sequence when doing prefill/training + mamba_conv_bias (`bool`, *optional*, defaults to `True`): + Flag indicating whether or not to use bias in the convolution layer of the mamba mixer block. + mamba_proj_bias (`bool`, *optional*, defaults to `False`): + Flag indicating whether or not to use bias in the input and output projections (["in_proj", "out_proj"]) of the mamba mixer block + + """ + + model_type = "bamba" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=128000, + tie_word_embeddings=False, + hidden_size=4096, + intermediate_size=14336, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=8, + hidden_act="silu", + initializer_range=0.02, + rms_norm_eps=1e-5, + use_cache=True, + num_logits_to_keep=1, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + max_position_embeddings=262144, + attention_dropout=0.0, + attn_layer_indices=None, + mamba_n_heads=128, + mamba_d_head="auto", + mamba_n_groups=1, + mamba_d_state=256, + mamba_d_conv=4, + mamba_expand=2, + mamba_chunk_size=256, + mamba_conv_bias=True, + mamba_proj_bias=False, + **kwargs, + ): + self.vocab_size = vocab_size + self.tie_word_embeddings = tie_word_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.max_position_embeddings = max_position_embeddings + self.attention_dropout = attention_dropout + self.attention_bias = False + self.mlp_bias = False + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + + self.use_cache = use_cache + self.num_logits_to_keep = num_logits_to_keep + + self.attn_layer_indices = attn_layer_indices + self.rope_theta = 10000.0 + self.rope_scaling = None + self.partial_rotary_factor = 0.5 + + mamba_intermediate = mamba_expand * hidden_size + + if mamba_intermediate % mamba_n_heads != 0: + raise 
ValueError("mamba_n_heads must divide mamba_expand * hidden_size") + + # for the mamba_v2, must satisfy the following + if mamba_d_head == "auto": + mamba_d_head = mamba_intermediate // mamba_n_heads + + if mamba_d_head * mamba_n_heads != mamba_intermediate: + raise ValueError("The dimensions for the Mamba head state do not match the model intermediate_size") + + self.mamba_n_heads = mamba_n_heads + self.mamba_d_head = mamba_d_head + self.mamba_n_groups = mamba_n_groups + self.mamba_d_state = mamba_d_state + self.mamba_d_conv = mamba_d_conv + self.mamba_expand = mamba_expand + self.mamba_chunk_size = mamba_chunk_size + self.mamba_conv_bias = mamba_conv_bias + self.mamba_proj_bias = mamba_proj_bias + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + @property + def layers_block_type(self): + return [ + "attention" if (self.attn_layer_indices and i in self.attn_layer_indices) else "mamba" + for i in range(self.num_hidden_layers) + ] + + +__all__ = ["BambaConfig"] diff --git a/src/transformers/models/bamba/convert_mamba_ssm_checkpoint.py b/src/transformers/models/bamba/convert_mamba_ssm_checkpoint.py new file mode 100644 index 00000000000000..a7b8cfc782907b --- /dev/null +++ b/src/transformers/models/bamba/convert_mamba_ssm_checkpoint.py @@ -0,0 +1,273 @@ +# coding=utf-8 +# Copyright 2024 IBM and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""This script can be used to convert checkpoints provided in the `mamba_ssm` library into the format provided in HuggingFace `transformers`. 
It depends on the `mamba2_ssm` package to be installed.""" + +import argparse +import json +import os +import re +from os import path +from typing import Dict, Union + +import torch +from huggingface_hub import split_torch_state_dict_into_shards +from safetensors.torch import save_file + +from transformers import AutoTokenizer +from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME + +from .configuration_bamba import BambaConfig + + +def convert_state_dict_from_mamba_ssm(original_sd: Dict) -> Dict[str, torch.Tensor]: + state_dict = {} + + for orig_k, param in original_sd.items(): + k = orig_k.replace("backbone", "model") + + # for embeddings + k = k.replace("embedding", "embed_tokens") + + # for mixer + k = k.replace("mixer", "mamba") + + # for final layernorm + k = k.replace("norm_f", "final_layernorm") + + # for block layernorm + k = re.sub(r"(\d+)\.norm\.", r"\1.input_layernorm.", k) + k = re.sub(r"(\d+)\.norm2\.", r"\1.pre_ff_layernorm.", k) + + # for mlp + k = k.replace("mlp.fc2", "feed_forward.down_proj") + + if "mlp.fc1" in k: + param, param2 = torch.chunk(param, 2, dim=0) + k2 = k.replace("mlp.fc1", "feed_forward.gate_proj") + state_dict[k2] = param2 + k = k.replace("mlp.fc1", "feed_forward.up_proj") + + if ("in_proj" in k and orig_k.replace("in_proj", "conv1d") in original_sd) or ( + "out_proj" in k and orig_k.replace("out_proj", "conv1d") in original_sd + ): + # then this must be a mamba + pass + else: + # for attn + # - because mixer was replaced to mamba above + k = k.replace("mamba.out_proj", "self_attn.o_proj") + if "mamba.in_proj" in k: + m, n = param.shape + d = (m - n) // 2 + param, param2, param3 = torch.split(param, [n, d, d], dim=0) + k2 = k.replace("mamba.in_proj", "self_attn.k_proj") + state_dict[k2] = param2 + k2 = k.replace("mamba.in_proj", "self_attn.v_proj") + state_dict[k2] = param3 + k = k.replace("mamba.in_proj", "self_attn.q_proj") + + state_dict[k] = param + + return state_dict + + +# Adapted from transformers.models.mamba.convert_mamba_ssm_checkpoint_to_pytorch.py +def convert_ssm_config_to_hf_config( + config_ssm: Dict, + **kwargs, +) -> BambaConfig: + """Convert a config from mamba_ssm to a BambaConfig from here.""" + hf_config: BambaConfig = BambaConfig(**kwargs) + + hf_config.architectures = ["BambaForCausalLM"] + + # Set important values from config and recalculate other resulting entries + hf_config.hidden_size = config_ssm["d_model"] + hf_config.intermediate_size = config_ssm["d_intermediate"] + hf_config.mamba_n_heads = (hf_config.hidden_size * hf_config.mamba_expand) // hf_config.mamba_d_head + hf_config.num_hidden_layers = config_ssm["n_layer"] + hf_config.tie_word_embeddings = config_ssm["tie_embeddings"] + + # currently this script assumes config_ssm belongs to v2 + if config_ssm["ssm_cfg"].get("layer") != "Mamba2": + raise ValueError("Conversion script only supports Mamba2") + + # Set attention values + attn_cfg = config_ssm.get("attn_cfg") + if attn_cfg: + assert attn_cfg["causal"], "Only support non-causal attention." + assert not attn_cfg["qkv_proj_bias"], "Only support no qkv bias." + assert not attn_cfg["out_proj_bias"], "Only support no out bias." 
+ hf_config.attn_rotary_emb = attn_cfg["rotary_emb_dim"] + hf_config.num_attention_heads = attn_cfg["num_heads"] + hf_config.num_key_value_heads = attn_cfg["num_heads_kv"] + + attention_layer_indices = config_ssm.get("attn_layer_idx") + if attention_layer_indices: + hf_config.attn_layer_indices = attention_layer_indices + + # Padded vocab size, mostly of 16 but 32 is also very common in different models + vocab_size = config_ssm["vocab_size"] + pad_vocab_size_multiple = config_ssm["pad_vocab_size_multiple"] + if (vocab_size % pad_vocab_size_multiple) != 0: + vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple) + hf_config.vocab_size = vocab_size + + return hf_config + + +def save_single_safetensor( + state_dict: Dict, + save_directory: str, + metadata: Dict, +): + save_file( + state_dict, + os.path.join(save_directory, SAFE_WEIGHTS_NAME), + metadata, + ) + + +def save_sharded_safetensors( + state_dict: Dict, + save_directory: str, + metadata: Dict, + max_shard_size: Union[int, str] = "5GB", +): + filename_pattern = SAFE_WEIGHTS_NAME.replace(".bin", "{suffix}.bin").replace( + ".safetensors", "{suffix}.safetensors" + ) + state_dict_split = split_torch_state_dict_into_shards( + state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size + ) + index = { + "metadata": state_dict_split.metadata, + "weight_map": state_dict_split.tensor_to_filename, + } + # Save the index + with open(os.path.join(save_directory, SAFE_WEIGHTS_INDEX_NAME), "w", encoding="utf-8") as f: + content = json.dumps(index, indent=2, sort_keys=True) + "\n" + f.write(content) + + filename_to_tensors = state_dict_split.filename_to_tensors.items() + for shard_file, tensors in filename_to_tensors: + shard = {tensor: state_dict[tensor].contiguous() for tensor in tensors} + save_file(shard, os.path.join(save_directory, shard_file), metadata=metadata) + + +# Adapted from transformers.models.mamba.convert_mamba_ssm_checkpoint_to_pytorch.py +def convert_mamba_ssm_checkpoint_file_to_huggingface_model_file( + mamba_ssm_checkpoint_path: str, + precision: str, + output_dir: str, + tokenizer_path: str = None, + save_model: Union[bool, str] = True, +) -> None: + # load tokenizer if provided, this will be used to set the + # token_ids in the config file + token_ids = {} + if tokenizer_path: + tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) + for key in [ + "bos_token_id", + "eos_token_id", + "pad_token_id", + ]: + id = getattr(tokenizer, key, None) + if id: + token_ids[key] = id + + # there are some configs unsettable by mamba_ssn config, so + # if there are changes from the defaults, have to pass them into + # the function + unsettables = { + "mamba_d_head": 64, + "mamba_d_state": 128, + "mamba_n_groups": 1, + "rms_norm_eps": 1e-5, + } + + # Load and save config based on name + config_path = path.join(mamba_ssm_checkpoint_path, "config.json") + with open(config_path, "r", encoding="utf-8") as json_file: + config = json.load(json_file) + + # convert the config + hf_config = convert_ssm_config_to_hf_config( + config_ssm=config, + **token_ids, + **unsettables, + ) + hf_config.save_pretrained(output_dir) + + # Load state dict of the original model and transfer to hf model + state_dict = torch.load( + path.join(mamba_ssm_checkpoint_path, "pytorch_model.bin"), + map_location="cpu", + weights_only=True, + ) + # FIXME: allow other parameters to pass in + state_dict = convert_state_dict_from_mamba_ssm(state_dict) + + # Save new model to pytorch_dump_path + dtype = torch.float32 if precision == 
"fp32" else (torch.bfloat16 if precision == "bf16" else torch.float16) + + save_file_fn = None + if isinstance(save_model, bool) and save_model: + save_file_fn = save_single_safetensor + elif isinstance(save_model, str) and save_model == "sharded": + save_file_fn = save_sharded_safetensors + + if save_file_fn: + save_file_fn({k: v.to(dtype) for k, v in state_dict.items()}, output_dir, metadata={"format": "pt"}) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "-i", + "--mamba_ssm_checkpoint_directory", + type=str, + required=True, + help="Path to a directory containing the `pytorch_model.bin` mamba_ssm checkpoint file to be converted.", + ) + parser.add_argument( + "-p", + "--precision", + type=str, + default="fp16", + const="fp16", + required=True, + choices=("fp32", "fp16", "bf16"), + help="The precision the model will be saved in. Select from fp32, fp16 or bf16.", + ) + parser.add_argument( + "-o", "--output_dir", type=str, required=True, help="Path to directory to save the converted output model to." + ) + parser.add_argument( + "-t", + "--tokenizer_model_path", + type=str, + default=None, + required=False, + help="Path to a the tokenizer file.", + ) + args = parser.parse_args() + + convert_mamba_ssm_checkpoint_file_to_huggingface_model_file( + args.mamba2_checkpoint_directory, + args.precision, + args.output_dir, + ) diff --git a/src/transformers/models/bamba/modeling_bamba.py b/src/transformers/models/bamba/modeling_bamba.py new file mode 100644 index 00000000000000..c89d8d7853008d --- /dev/null +++ b/src/transformers/models/bamba/modeling_bamba.py @@ -0,0 +1,1615 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/bamba/modular_bamba.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_bamba.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2024 IBM and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Callable, Optional, Tuple, Union + +import torch +from torch import nn + +import transformers.models.jamba.modeling_jamba as modeling_jamba +from transformers.activations import ACT2FN + +from ...cache_utils import Cache # we need __iter__ and __len__ of pkv +from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings +from ...utils.import_utils import ( + is_causal_conv1d_available, + is_mamba_2_ssm_available, +) +from .configuration_bamba import BambaConfig + + +if is_mamba_2_ssm_available(): + from mamba_ssm.ops.triton.selective_state_update import selective_state_update + from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined +else: + selective_state_update = None + +if is_causal_conv1d_available(): + from causal_conv1d import causal_conv1d_fn, causal_conv1d_update +else: + causal_conv1d_update, causal_conv1d_fn = None, None + + +logger = logging.get_logger(__name__) +_CONFIG_FOR_DOC = "BambaConfig" + + +# Adapted from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache for the v2 mixer +class HybridMambaAttentionDynamicCache(modeling_jamba.HybridMambaAttentionDynamicCache): + """ + A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache + (which has a constant shape regardless of seq_len). + + This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states` + and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor + For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`, + while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors). + For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors), + while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`, + and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`. 
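The shapes above follow directly from the configuration hyper-parameters used in the `__init__` below. A rough back-of-the-envelope sketch, with illustrative values rather than the real Bamba-9B configuration:

```python
# Illustrative arithmetic for the per-layer cache shapes of a mamba block,
# mirroring the constructor below. The numbers are made up for the example.
batch_size = 2
hidden_size = 1024
mamba_expand, mamba_n_groups, mamba_d_state, mamba_d_conv = 2, 1, 256, 4
mamba_n_heads = 64
mamba_d_head = (mamba_expand * hidden_size) // mamba_n_heads  # the "auto" head size

conv_dim = mamba_expand * hidden_size + 2 * mamba_n_groups * mamba_d_state
conv_state_shape = (batch_size, conv_dim, mamba_d_conv)                     # (2, 2560, 4)
ssm_state_shape = (batch_size, mamba_n_heads, mamba_d_head, mamba_d_state)  # (2, 64, 32, 256)
```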
+ """ + + def __init__(self, config: BambaConfig, batch_size, dtype=torch.float16, device=None): + super().__init__(config, batch_size, dtype, device) + self.layers_block_type = config.layers_block_type + self.has_previous_state = False # only used by mamba + conv_kernel_size = config.mamba_d_conv + ssm_state_size = config.mamba_d_state + + self.conv_states = [] + self.ssm_states = [] + self.transformer_layers = [] + for i in range(config.num_hidden_layers): + if self.layers_block_type[i] == "mamba": + self.conv_states += [ + torch.zeros( + batch_size, + (config.mamba_expand * config.hidden_size + 2 * config.mamba_n_groups * ssm_state_size), + conv_kernel_size, + device=device, + dtype=dtype, + ) + ] + self.ssm_states += [ + torch.zeros( + batch_size, + config.mamba_n_heads, + config.mamba_d_head, + ssm_state_size, + device=device, + dtype=dtype, + ) + ] + else: + self.conv_states += [torch.tensor([[]] * batch_size, device=device)] + self.ssm_states += [torch.tensor([[]] * batch_size, device=device)] + self.transformer_layers.append(i) + + self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] + self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] + + +class BambaRotaryEmbedding(nn.Module): + def __init__( + self, + config: BambaConfig, + device=None, + ): + super().__init__() + self.rope_kwargs = {} + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with 
torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +# Adapted from transformers.models.glm.modular_glm.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Removes the interleaving of cos and sin from GLM + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
+ """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + + # Keep half or full tensor for later concatenation + rotary_dim = cos.shape[-1] + q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] + k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] + + # Apply rotary embeddings on the first half or full tensor + q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin) + k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin) + + # Concatenate back to full shape + q_embed = torch.cat([q_embed, q_pass], dim=-1) + k_embed = torch.cat([k_embed, k_pass], dim=-1) + return q_embed, k_embed + + +class BambaAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: BambaConfig, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
+ ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class BambaRMSNormGated(torch.nn.Module): + def __init__(self, hidden_size, eps=1e-6): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states, gate=None): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + + if gate is not None: + hidden_states = hidden_states * nn.functional.silu(gate.to(torch.float32)) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + + return self.weight * hidden_states.to(input_dtype) + + +# Helper methods for segment sum computation + + +def pad_tensor_by_size(input_tensor: torch.Tensor, pad_size: int): + """ + Padding x tensor with `pad_size` on the seq_len dim (dim=1) + + Assumes that we only have tensors of either size 4 or 3 + """ + pad_shape = (0, 0, 0, 0, 0, pad_size, 0, 0) if len(input_tensor.shape) == 4 else (0, 0, 0, pad_size, 0, 0) + + return torch.nn.functional.pad(input_tensor, pad_shape, mode="constant", value=0) + + +def reshape_into_chunks(input_tensor, pad_size, chunk_size): + """ + Padding input_tensor with `pad_size` on the seq_len dim (dim=1) and + simultaneously splitting it into chunk sequences. + + Assumes that we only have tensors of either size 4 or 3 + """ + # [bsz, seq_len, ...] -> [bsz, seq_len multiple of chunk_size, ...] + input_tensor = pad_tensor_by_size(input_tensor, pad_size) + + if len(input_tensor.shape) == 3: + # [bsz, seq_len multiple of chunk_size, num_heads] -> [bsz, -1, chunk_size, num_heads] + return input_tensor.reshape(input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2]) + else: + # [bsz, seq_len multiple of chunk_size, num_heads, head_dim or state_size] -> [bsz, -1, chunk_size, num_heads, head_dim or state_size] + return input_tensor.reshape( + input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2], input_tensor.shape[3] + ) + + +def segment_sum(input_tensor): + """ + More stable segment sum calculation. Uses cumulative sums and masking instead of direct subtractions. + """ + chunk_size = input_tensor.size(-1) + # 1. expand input tensor to have an additional dimension and repeat along that dimension + # [..., chunk_size] -> [..., chunk_size, chunk_size] + input_tensor = input_tensor[..., None].expand(*input_tensor.size(), chunk_size) + # 2. create a lower triangular mask with the diagonal set to 0 to 0 out elements above diag + mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=-1) + input_tensor = input_tensor.masked_fill(~mask, 0) + # 3. compute actual cumsum + tensor_segsum = torch.cumsum(input_tensor, dim=-2) + + # 4. 
apply mask to keep only the lower triangular part of the cumulative sum result (incl diagonal this time) + mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=0) + tensor_segsum = tensor_segsum.masked_fill(~mask, -torch.inf) + return tensor_segsum + + +is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update)) + + +def apply_mask_to_padding_states(hidden_states, attention_mask): + """ + Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66 + """ + if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: + dtype = hidden_states.dtype + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) + + return hidden_states + + +# Adapted from transformers.models.mamba2.modeling_mamba2.Mamba2Mixer +class BambaMixer(nn.Module): + """ + Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`. + A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective) + ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4, + and is why Mamba is called **selective** state spaces) + + The are a few differences between this and Mamba2Mixer: + - The variable use_precomputed_states is slightly different due to the HybridCache structure + - There's a few non-obvious bugs fixed with batching in the slow path that exist in main + - Some extra variables that our layer doesn't need have been removed + - We ported most of the refactors in https://github.com/huggingface/transformers/pull/35154, which is (as of Dec 18, 2024) unmerged + """ + + def __init__(self, config: BambaConfig, layer_idx: int): + super().__init__() + self.num_heads = config.mamba_n_heads + self.hidden_size = config.hidden_size + self.ssm_state_size = config.mamba_d_state + self.conv_kernel_size = config.mamba_d_conv + self.intermediate_size = int(config.mamba_expand * self.hidden_size) + self.layer_idx = layer_idx + self.use_conv_bias = config.mamba_conv_bias + self.activation = config.hidden_act + self.act = ACT2FN[config.hidden_act] + self.use_bias = config.mamba_proj_bias + + self.layer_norm_epsilon = config.rms_norm_eps + + self.n_groups = config.mamba_n_groups + self.head_dim = config.mamba_d_head + self.chunk_size = config.mamba_chunk_size + + # FIXME: + self.time_step_limit = (0.0, float("inf")) + self.time_step_min = 0.001 + self.time_step_max = 0.1 + + self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.ssm_state_size + self.conv1d = nn.Conv1d( + in_channels=self.conv_dim, + out_channels=self.conv_dim, + bias=config.mamba_conv_bias, + kernel_size=self.conv_kernel_size, + groups=self.conv_dim, + padding=self.conv_kernel_size - 1, + ) + + # projection of the input hidden states + projection_size = self.intermediate_size + self.conv_dim + self.num_heads + self.in_proj = nn.Linear( + self.hidden_size, + projection_size, + bias=self.use_bias, + ) + # selective projection used to make dt, B and C input dependant + + # time step projection (discretization) + # instantiate once and copy inv_dt in init_weights of PretrainedModel + self.dt_bias = nn.Parameter(torch.ones(self.num_heads)) + + # S4D real initialization. These are not discretized! + # The core is to load them, compute the discrete states, then write the updated state. 
Keeps the memory bounded + A = torch.arange(1, self.num_heads + 1) + self.A_log = nn.Parameter(torch.log(A)) + self.A_log._no_weight_decay = True + self.norm = BambaRMSNormGated(self.intermediate_size, eps=self.layer_norm_epsilon) + self.D = nn.Parameter(torch.ones(self.num_heads)) + self.D._no_weight_decay = True + + self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=self.use_bias) + + if not is_fast_path_available: + logger.warning_once( + "The fast path is not available because on of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`" + " is None. Falling back to the naive implementation. To install follow https://github.com/state-spaces/mamba/#installation and" + " https://github.com/Dao-AILab/causal-conv1d" + ) + else: + logger.warning_once("The fast path for Bamba will be used when running the model on a GPU") + + def cuda_kernels_forward( + self, + hidden_states: torch.Tensor, + cache_params: Optional[HybridMambaAttentionDynamicCache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ): + # 1. Gated MLP's linear projection + hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask) + projected_states = self.in_proj(hidden_states) + + # Set up dimensions for reshapes later + batch_size, seq_len, _ = hidden_states.shape + groups_time_state_size = self.n_groups * self.ssm_state_size + + use_precomputed_states = ( + cache_params is not None + and cache_params.has_previous_state + and seq_len == 1 + and cache_params.conv_states[self.layer_idx].shape[0] + == cache_params.ssm_states[self.layer_idx].shape[0] + == batch_size + and cache_position is not None + and cache_position[0] > 0 + ) + + # getting projected states from cache if it exists + if use_precomputed_states: + gate, hidden_states_B_C, dt = projected_states.squeeze(1).split( + [self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 + ) + + # 2. Convolution sequence transformation + hidden_states_B_C = causal_conv1d_update( + hidden_states_B_C, + cache_params.conv_states[self.layer_idx], + self.conv1d.weight.squeeze(1), + self.conv1d.bias, + self.activation, + ) + + hidden_states, B, C = torch.split( + hidden_states_B_C, + [self.intermediate_size, groups_time_state_size, groups_time_state_size], + dim=-1, + ) + + # 3. SSM transformation + A = -torch.exp(self.A_log.float()) # (nheads,) + A = A[:, None, ...][:, :, None].expand(-1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32) + dt = dt[:, :, None].expand(-1, -1, self.head_dim) + dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim) + D = self.D[:, None, ...].expand(-1, self.head_dim) + B = B.view(batch_size, self.n_groups, B.shape[1] // self.n_groups) + C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups) + hidden_states_reshaped = hidden_states.view(batch_size, self.num_heads, self.head_dim) + hidden_states = selective_state_update( + cache_params.ssm_states[self.layer_idx], + hidden_states_reshaped, + dt, + A, + B, + C, + D, + z=None, + dt_bias=dt_bias, + dt_softplus=True, + ) + hidden_states = hidden_states.view(batch_size, self.num_heads * self.head_dim) + hidden_states = self.norm(hidden_states, gate) + + # 4. Final linear projection + out = self.out_proj(hidden_states)[:, None, ...] 
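+ # NOTE: the branch above is the cached single-token fast path (causal_conv1d_update + selective_state_update); the else branch below processes full sequences with the fused mamba_split_conv1d_scan_combined / mamba_chunk_scan_combined kernels.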
+ # Fused calculations or step by step if no initialized cache is found + else: + A = -torch.exp(self.A_log.float()) # (num_heads) or (intermediate_size, state_size) + dt_limit_kwargs = {} if self.time_step_limit == (0.0, float("inf")) else {"dt_limit": self.time_step_limit} + + # 2-4. Fused kernel for conv1d, SSM, and the final projection + if self.training and cache_params is None: + out = mamba_split_conv1d_scan_combined( + projected_states, + self.conv1d.weight.squeeze(1), + self.conv1d.bias, + self.dt_bias, + A, + D=self.D, + chunk_size=self.chunk_size, + seq_idx=None, # was seq_idx + activation=self.activation, + rmsnorm_weight=self.norm.weight, + rmsnorm_eps=self.norm.variance_epsilon, + outproj_weight=self.out_proj.weight, + outproj_bias=self.out_proj.bias, + headdim=self.head_dim, + ngroups=self.n_groups, + norm_before_gate=False, + return_final_states=False, + **dt_limit_kwargs, + ) + + else: + gate, hidden_states_B_C, dt = projected_states.split( + [self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 + ) + + # 2. Convolution sequence transformation + # Init cache + if cache_params is not None: + # storing the states + # If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv + # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise. + hidden_states_B_C_transposed = hidden_states_B_C.transpose(1, 2) + conv_states = nn.functional.pad( + hidden_states_B_C_transposed, + (self.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0), + ) + cache_params.conv_states[self.layer_idx].copy_(conv_states) + + if self.activation not in ["silu", "swish"]: + hidden_states_B_C = self.act( + self.conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2) + ) + else: + hidden_states_B_C = causal_conv1d_fn( + x=hidden_states_B_C.transpose(1, 2), + weight=self.conv1d.weight.squeeze(1), + bias=self.conv1d.bias, + activation=self.activation, + ).transpose(1, 2) + + hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask) + hidden_states, B, C = torch.split( + hidden_states_B_C, + [self.intermediate_size, groups_time_state_size, groups_time_state_size], + dim=-1, + ) + + # 3. SSM transformation + scan_output, ssm_state = mamba_chunk_scan_combined( + hidden_states.view(batch_size, seq_len, -1, self.head_dim), + dt, + A, + B.view(batch_size, seq_len, self.n_groups, -1), + C.view(batch_size, seq_len, self.n_groups, -1), + chunk_size=self.chunk_size, + D=self.D, + z=None, + seq_idx=None, + return_final_states=True, + dt_bias=self.dt_bias, + dt_softplus=True, + **dt_limit_kwargs, + ) + + # Init cache + if ssm_state is not None and cache_params is not None: + cache_params.ssm_states[self.layer_idx].copy_(ssm_state) + + scan_output = scan_output.view(batch_size, seq_len, -1) + # Multiply "gate" branch and apply extra normalization layer + scan_output = self.norm(scan_output, gate) + + # 4. Final linear projection + out = self.out_proj(scan_output) + return out + + # fmt: off + def torch_forward( + self, + input_states, + cache_params: Optional[HybridMambaAttentionDynamicCache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ): + batch_size, seq_len, _ = input_states.shape + dtype = input_states.dtype + + # 1. 
Gated MLP's linear projection + input_states = apply_mask_to_padding_states(input_states, attention_mask) + projected_states = self.in_proj(input_states) + gate, hidden_states_B_C, dt = projected_states.split( + [self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 + ) + + use_precomputed_states = ( + cache_params is not None + and cache_params.has_previous_state + and seq_len == 1 + and cache_params.conv_states[self.layer_idx].shape[0] + == cache_params.ssm_states[self.layer_idx].shape[0] + == batch_size + and cache_position is not None + and cache_position[0] > 0 + ) + + # 2. Convolution sequence transformation + if use_precomputed_states: + cache_params.conv_states[self.layer_idx] = cache_params.conv_states[self.layer_idx].roll(shifts=-1, dims=-1) + cache_params.conv_states[self.layer_idx][:, :, -1] = hidden_states_B_C[:, 0, :].to(cache_params.conv_states[self.layer_idx].device) + + # We need to guarantee that anything regarding the cache is on the same device + conv_states = cache_params.conv_states[self.layer_idx].to(device=self.conv1d.weight.device) + + hidden_states_B_C = torch.sum( + conv_states * self.conv1d.weight.squeeze(1), dim=-1 + ) + if self.use_conv_bias: + hidden_states_B_C = hidden_states_B_C + self.conv1d.bias + hidden_states_B_C = self.act(hidden_states_B_C) + else: + # Init cache + if cache_params is not None: + hidden_states_B_C_transposed = hidden_states_B_C.transpose(1, 2) + conv_states = nn.functional.pad( + hidden_states_B_C_transposed, (self.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0) + ) + cache_params.conv_states[self.layer_idx].copy_(conv_states) + + hidden_states_B_C = self.act(self.conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2)) + + hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask) + hidden_states, B, C = torch.split( + hidden_states_B_C, + [self.intermediate_size, self.n_groups * self.ssm_state_size, self.n_groups * self.ssm_state_size], + dim=-1 + ) + + # 3. SSM transformation + A = -torch.exp(self.A_log.float()) # [num_heads] + if use_precomputed_states: + # We need to guarantee that anything regarding the cache is on the same device + cache_device = cache_params.ssm_states[self.layer_idx].device + + # Note: there is no need to pad parameter matrices here, as there is just one new token + # for batched generation + dt = dt[:, 0, :][:, None, ...] 
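+ # Single-step recurrence for cached decoding: dt is softplus-ed and clamped, the cached SSM state is updated as ssm_state = exp(dt * A) * ssm_state + (dt * B) * x, and the readout below is y = C @ ssm_state + D * x.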
+ dt = dt.transpose(1, 2).expand(batch_size, dt.shape[-1], self.head_dim) + # [num_heads] -> [num_heads, head_dim] + dt_bias = self.dt_bias[..., None].expand(self.dt_bias.shape[0], self.head_dim) + + dt = torch.nn.functional.softplus(dt + dt_bias.to(dt.dtype)) + dt = torch.clamp(dt, self.time_step_limit[0], self.time_step_limit[1]) + A = A[..., None, None].expand(self.num_heads, self.head_dim, self.ssm_state_size).to(dtype=torch.float32) + # [bsz, num_heads, head_dim, state_size] + dA = (torch.exp(dt[..., None] * A)).to(device=cache_device) + + # Discretize B + # [bsz, n_groups * state_size] -> [bsz, n_groups, 1, state_size] -> + # -> [bsz, n_groups, group to head repetition factor, state_size] -> [bsz, num_heads, state_size] + B = B.reshape(batch_size, self.n_groups, -1)[..., None, :] + B = B.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, B.shape[-1]).contiguous() + B = B.reshape(batch_size, -1, B.shape[-1]) + # [bsz, num_heads, head_dim, state_size] + dB = dt[..., None] * B[..., None, :] + + # Discretize x into dB + # [bsz, intermediate_size] -> [bsz, num_heads, head_dim] + hidden_states = hidden_states.reshape(batch_size, -1, self.head_dim) + dBx = (dB * hidden_states[..., None]).to(device=cache_device) + + # State calculation + cache_params.ssm_states[self.layer_idx].copy_( + cache_params.ssm_states[self.layer_idx] * dA + dBx + ) + + # Subsequent output + # [bsz, n_groups * state_size] -> [bsz, num_heads, state_size] + C = C.reshape(batch_size, self.n_groups, -1)[..., None, :] + C = C.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, C.shape[-1]).contiguous() + C = C.reshape(batch_size, -1, C.shape[-1]) + # [bsz, num_heads, head_dim] + + ssm_states = cache_params.ssm_states[self.layer_idx].to(device=C.device, dtype=C.dtype) # Shape: [b, h, d, n] + # Reshape ssm_states to merge the first two dimensions + ssm_states_reshaped = ssm_states.view(batch_size * self.num_heads, self.head_dim, self.ssm_state_size) # Shape: [b*h, d, n] + C_reshaped = C.view(batch_size * self.num_heads, self.ssm_state_size, 1) # Shape: [b*h, n, 1] + y = torch.bmm(ssm_states_reshaped, C_reshaped) + y = y.view(batch_size, self.num_heads, self.head_dim) + + # D skip connection + # [num_heads] -> [num_heads, head_dim] + D = self.D[..., None].expand(self.D.shape[0], self.head_dim) + y = (y + hidden_states * D).to(y.dtype) + + # [bsz, num_heads, head_dim] -> [bsz, 1, intermediate_size] + y = y.reshape(batch_size, -1)[:, None, ...] 
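+ # End of the cached single-token (recurrent) path: y has shape (batch_size, 1, intermediate_size). The else branch below is the chunked SSD algorithm used for full sequences.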
+ else: + # begin ssd naive implementation without einsums + dt = nn.functional.softplus(dt + self.dt_bias) + dt = torch.clamp(dt, self.time_step_limit[0], self.time_step_limit[1]) + hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).float() + B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).float() + C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).float() + B = B.repeat(1, 1, self.num_heads // self.n_groups, 1) + C = C.repeat(1, 1, self.num_heads // self.n_groups, 1) + pad_size = (self.chunk_size - seq_len % self.chunk_size) % self.chunk_size + + D_residual = self.D[..., None] * pad_tensor_by_size(hidden_states, pad_size) + + # Discretize x and A + hidden_states = hidden_states * dt[..., None] + A = A.to(hidden_states.dtype) * dt + + # Rearrange into blocks/chunks + hidden_states, A, B, C = [reshape_into_chunks(t, pad_size, self.chunk_size) for t in (hidden_states, A, B, C)] + + # [bsz, -1, chunk_size, num_heads] -> [bsz, num_heads, -1, chunk_size] + A = A.permute(0, 3, 1, 2) + A_cumsum = torch.cumsum(A, dim=-1) + + # 1. Compute the output for each intra-chunk (diagonal blocks) + # This is the analog of a causal mask + L = torch.exp(segment_sum(A)) + + # Contraction of C and B to get G (attention-weights like) + G_intermediate = C[:, :, :, None, :, :] * B[:, :, None, :, :, :] # shape: (b, c, l, s, h, n) + G = G_intermediate.sum(dim=-1) # shape: (b, c, l, s, h) + + # Compute M, equivalent to applying attention mask to weights + M_intermediate = G[..., None] * L.permute(0, 2, 3, 4, 1)[..., None] + M = M_intermediate.sum(dim=-1) + + # Compute Y_diag (apply to values) + Y_diag = (M[..., None] * hidden_states[:, :, None]).sum(dim=3) + + # 2. Compute the state for each intra-chunk + # (right term of low-rank factorization of off-diagonal blocks; B terms) + decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum)) + B_decay = B * decay_states.permute(0, -2, -1, 1)[..., None] + states = (B_decay[..., None, :] * hidden_states[..., None]).sum(dim=2) + + # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries + # (middle term of factorization of off-diag blocks; A terms) + if use_precomputed_states: + previous_states = cache_params.ssm_states[self.layer_idx][:, None, ...].to(device=states.device) + else: + previous_states = torch.zeros_like(states[:, :1]) + states = torch.cat([previous_states, states], dim=1) + decay_chunk = torch.exp(segment_sum(nn.functional.pad(A_cumsum[:, :, :, -1], (1, 0)))) + decay_chunk = decay_chunk.transpose(1, 3) + new_states = (decay_chunk[..., None, None] * states[:, :, None, ...]).sum(dim=1) + states, ssm_state = new_states[:, :-1], new_states[:, -1] + + # 4. 
Compute state -> output conversion per chunk + # (left term of low-rank factorization of off-diagonal blocks; C terms) + state_decay_out = torch.exp(A_cumsum) + C_times_states = (C[..., None, :] * states[:, :, None, ...]) + state_decay_out_permuted = state_decay_out.permute(0, 2, 3, 1) + Y_off = (C_times_states.sum(-1) * state_decay_out_permuted[..., None]) + + # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks) + y = Y_diag + Y_off + # [bsz, -1, self.chunk_size, num_heads, head_dim] -> [bsz, (padded) seq_len, num_heads, head_dim] + y = y.reshape(batch_size, -1, self.num_heads, self.head_dim) + + y = y + D_residual + # Cutting off padded chunks + if pad_size > 0: + y = y[:, :seq_len, :, :] + y = y.reshape(batch_size, seq_len, -1) + + # Init cache + if ssm_state is not None and cache_params is not None: + cache_params.ssm_states[self.layer_idx].copy_(ssm_state) + + scan_output = self.norm(y, gate) + + # end ssd naive + + # 4. Final linear projection + contextualized_states = self.out_proj(scan_output.to(dtype)) # [batch, seq_len, hidden_size] + return contextualized_states + # fmt: on + + def forward( + self, + hidden_states, + cache_params: Optional[HybridMambaAttentionDynamicCache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ): + if is_fast_path_available and "cuda" in self.in_proj.weight.device.type: + return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask) + dtype = hidden_states.dtype + if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: + # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) + + return self.torch_forward(hidden_states, cache_params, cache_position, attention_mask) + + +class BambaMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +class BambaRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + BambaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class BambaDecoderLayer(nn.Module): + def __init__(self, config: BambaConfig, layer_idx: int, layer_type: str = "mamba"): + super().__init__() + + num_experts = 1 + ffn_layer_class = BambaMLP if num_experts == 1 else None + self.feed_forward = ffn_layer_class(config) + self.input_layernorm = BambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 
+ self.pre_ff_layernorm = BambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.layer_type = layer_type + if layer_type == "mamba": + self.mamba = BambaMixer(config=config, layer_idx=layer_idx) + elif layer_type == "attention": + self.self_attn = BambaAttention(config, layer_idx) + else: + raise ValueError("Invalid layer_type") + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + past_key_value (`HybridMambaAttentionDynamicCache`, *optional*): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # this is a hybrid decoder layer + if self.layer_type == "mamba": + hidden_states = self.mamba( + hidden_states=hidden_states, + cache_params=past_key_value, + cache_position=cache_position, + attention_mask=attention_mask, + ) + self_attn_weights = None + elif self.layer_type == "attention": + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + + # residual connection after attention + hidden_states = residual + hidden_states + + # feed-forward + residual = hidden_states + hidden_states = self.pre_ff_layernorm(hidden_states) + hidden_states = self.feed_forward(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + return outputs + + +BAMBA_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) 
+ + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`BambaConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare BambaModel outputting raw hidden-states without any specific head on top.", + BAMBA_START_DOCSTRING, +) +class BambaPreTrainedModel(PreTrainedModel): + config_class = BambaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["BambaDecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True # Note: only supports HybridMambaAttentionDynamicCache + _is_stateful = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, (nn.Linear, nn.Conv1d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +BAMBA_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`HybridMambaAttentionDynamicCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + A HybridMambaAttentionDynamicCache object containing pre-computed hidden-states (keys and values in the + self-attention blocks and convolution and ssm states in the mamba blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. 
+ Key and value cache tensors have shape `(batch_size, num_heads, seq_len, head_dim)`. + Convolution and ssm states tensors have shape `(batch_size, d_inner, d_conv)` and + `(batch_size, d_inner, d_state)` respectively. + See the `HybridMambaAttentionDynamicCache` class for more details. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + output_router_logits (`bool`, *optional*): + Whether or not to return the logits of all the routers. They are useful for computing the router loss, and + should not be returned during inference. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + "The bare Bamba Model outputting raw hidden-states without any specific head on top.", + BAMBA_START_DOCSTRING, +) +# Adapted from transformers.models.jamba.modeling_jamba.JambaModel +class BambaModel(BambaPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`BambaDecoderLayer`] + + Args: + config: BambaConfig + """ + + def __init__(self, config: BambaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + decoder_layers = [] + for i in range(config.num_hidden_layers): + decoder_layers.append(BambaDecoderLayer(config, layer_idx=i, layer_type=config.layers_block_type[i])) + self.layers = nn.ModuleList(decoder_layers) + + self._attn_implementation = config._attn_implementation + self.final_layernorm = BambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = BambaRotaryEmbedding(config=config) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(BAMBA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[HybridMambaAttentionDynamicCache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + hidden_states = inputs_embeds + + if use_cache and past_key_values is None: + logger.warning_once( + "Bamba requires an initialized `HybridMambaAttentionDynamicCache` to return a cache. None was " + "provided, so no cache will be returned." 
+ ) + + if cache_position is None: + cache_position = torch.arange(hidden_states.shape[1], device=hidden_states.device) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + mamba_mask = self._update_mamba_mask(attention_mask, cache_position) + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers: + # Depending on the layer type we opt for 2D base attention mask (Mamba) or 4D causal mask (Attention) + layer_mask = mamba_mask if decoder_layer.layer_type == "mamba" else causal_mask + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + layer_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=layer_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + if layer_outputs[1] is not None: + # append attentions only of attention layers. Mamba layers return `None` as the attention weights + all_self_attns += (layer_outputs[1],) + + hidden_states = self.final_layernorm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if past_key_values and not past_key_values.has_previous_state: + past_key_values.has_previous_state = True + + next_cache = None if not use_cache else past_key_values + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: HybridMambaAttentionDynamicCache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. 
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + sequence_length = input_tensor.shape[1] + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
+ causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_attention_mask = (attention_mask[:, None, None, :] == attention_mask[:, None, :, None])[ + :, :, -sequence_length:, : + ].to(dtype) + padding_mask = causal_mask[:, :, :, :mask_length] + padding_attention_mask + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + def _update_mamba_mask(self, attention_mask, cache_position): + """ + No need for zeroing states when + 1. Cached forward + 2. Attending to all inputs + """ + mamba_mask = attention_mask + if cache_position[0] > 0 or (attention_mask is not None and torch.all(attention_mask == 1)): + mamba_mask = None + return mamba_mask + + +class BambaForCausalLM(BambaPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + + def __init__(self, config): + super().__init__(config) + self.model = BambaModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(BAMBA_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[HybridMambaAttentionDynamicCache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, + **kwargs, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + num_logits_to_keep (`int` or `None`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `None`, calculate logits for all + `input_ids`. 
Only last token logits are needed for generation, and calculating them only for that token + can save memory, which becomes pretty significant for long sequences. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, BambaForCausalLM + + >>> model = BambaForCausalLM.from_pretrained("...") + >>> tokenizer = AutoTokenizer.from_pretrained("...") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + **kwargs, + ): + # Overwitten -- has a unique cache type, `HybridMambaAttentionDynamicCache` + + empty_past_kv = past_key_values is None + + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here + if not empty_past_kv: + if inputs_embeds is not None: # Exception 1 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + else: + past_key_values = HybridMambaAttentionDynamicCache( + self.config, input_ids.shape[0], self.dtype, device=self.device + ) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if not empty_past_kv: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` 
are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and empty_past_kv: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + "num_logits_to_keep": self.config.num_logits_to_keep, + "cache_position": cache_position, + } + ) + return model_inputs + + +__all__ = ["BambaModel", "BambaForCausalLM", "BambaPreTrainedModel"] diff --git a/src/transformers/models/bamba/modular_bamba.py b/src/transformers/models/bamba/modular_bamba.py new file mode 100644 index 00000000000000..7fb35f48fb3b76 --- /dev/null +++ b/src/transformers/models/bamba/modular_bamba.py @@ -0,0 +1,1303 @@ +# coding=utf-8 +# Copyright 2024 IBM and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
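A note on the `num_logits_to_keep` slicing used in `BambaForCausalLM.forward` above: because `-0` equals `0` in Python, the default `num_logits_to_keep=0` keeps logits for every position, while `1` keeps only the final token that generation actually needs. A minimal standalone check (tensor sizes are arbitrary, not taken from the model):

```python
import torch

hidden_states = torch.randn(2, 5, 16)   # (batch, seq_len, hidden_size)

# num_logits_to_keep == 0: `-0:` is the same as `0:`, i.e. keep all positions.
print(hidden_states[:, -0:, :].shape)   # torch.Size([2, 5, 16])

# num_logits_to_keep == 1: keep only the final position, the one generation needs.
print(hidden_states[:, -1:, :].shape)   # torch.Size([2, 1, 16])
```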
+"""PyTorch Bamba model.""" + +from typing import Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn + +import transformers.models.jamba.modeling_jamba as modeling_jamba +from transformers.activations import ACT2FN +from transformers.models.jamba.modeling_jamba import JambaAttentionDecoderLayer +from transformers.models.llama.modeling_llama import ( + LlamaAttention, + LlamaForCausalLM, + LlamaMLP, + LlamaRMSNorm, + LlamaRotaryEmbedding, + rotate_half, +) +from transformers.models.mamba2.modeling_mamba2 import ( + MambaRMSNormGated, + pad_tensor_by_size, + reshape_into_chunks, + segment_sum, +) + +from ...modeling_attn_mask_utils import ( + AttentionMaskConverter, +) +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_utils import PreTrainedModel +from ...utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from ...utils.import_utils import ( + is_causal_conv1d_available, + is_flash_attn_2_available, + is_mamba_2_ssm_available, +) +from .configuration_bamba import BambaConfig + + +if is_flash_attn_2_available(): + pass + +if is_mamba_2_ssm_available(): + from mamba_ssm.ops.triton.selective_state_update import selective_state_update + from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined +else: + selective_state_update = None + +if is_causal_conv1d_available(): + from causal_conv1d import causal_conv1d_fn, causal_conv1d_update +else: + causal_conv1d_update, causal_conv1d_fn = None, None + +is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update)) + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "BambaConfig" + + +# Adapted from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache for the v2 mixer +class HybridMambaAttentionDynamicCache(modeling_jamba.HybridMambaAttentionDynamicCache): + """ + A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache + (which has a constant shape regardless of seq_len). + + This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states` + and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor + For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`, + while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors). + For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors), + while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`, + and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`. 
+ """ + + def __init__(self, config: BambaConfig, batch_size, dtype=torch.float16, device=None): + super().__init__(config, batch_size, dtype, device) + self.layers_block_type = config.layers_block_type + self.has_previous_state = False # only used by mamba + conv_kernel_size = config.mamba_d_conv + ssm_state_size = config.mamba_d_state + + self.conv_states = [] + self.ssm_states = [] + self.transformer_layers = [] + for i in range(config.num_hidden_layers): + if self.layers_block_type[i] == "mamba": + self.conv_states += [ + torch.zeros( + batch_size, + (config.mamba_expand * config.hidden_size + 2 * config.mamba_n_groups * ssm_state_size), + conv_kernel_size, + device=device, + dtype=dtype, + ) + ] + self.ssm_states += [ + torch.zeros( + batch_size, + config.mamba_n_heads, + config.mamba_d_head, + ssm_state_size, + device=device, + dtype=dtype, + ) + ] + else: + self.conv_states += [torch.tensor([[]] * batch_size, device=device)] + self.ssm_states += [torch.tensor([[]] * batch_size, device=device)] + self.transformer_layers.append(i) + + self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] + self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] + + +class BambaRotaryEmbedding(LlamaRotaryEmbedding): + pass + + +# Adapted from transformers.models.glm.modular_glm.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Removes the interleaving of cos and sin from GLM + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
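The function body that follows applies the rotation only to the first `rotary_dim` features and passes the remaining features through unchanged. A self-contained, shape-level illustration of that partial rotary application, using random values, arbitrary sizes, and a local copy of the Llama-style `rotate_half` helper:

```python
import torch

def rotate_half(x):
    # Same convention as the Llama helper imported above: (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

# Illustrative shapes: only the first `rotary_dim` features are rotated.
batch, heads, seq_len, head_dim, rotary_dim = 1, 2, 3, 8, 4
q = torch.randn(batch, heads, seq_len, head_dim)
cos = torch.randn(batch, seq_len, rotary_dim).unsqueeze(1)   # unsqueeze_dim=1
sin = torch.randn(batch, seq_len, rotary_dim).unsqueeze(1)

q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
q_embed = torch.cat([(q_rot * cos) + (rotate_half(q_rot) * sin), q_pass], dim=-1)
print(q_embed.shape)  # torch.Size([1, 2, 3, 8])
```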
+ """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + + # Keep half or full tensor for later concatenation + rotary_dim = cos.shape[-1] + q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] + k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] + + # Apply rotary embeddings on the first half or full tensor + q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin) + k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin) + + # Concatenate back to full shape + q_embed = torch.cat([q_embed, q_pass], dim=-1) + k_embed = torch.cat([k_embed, k_pass], dim=-1) + return q_embed, k_embed + + +class BambaAttention(LlamaAttention): + pass + + +class BambaRMSNormGated(MambaRMSNormGated): + pass + + +def apply_mask_to_padding_states(hidden_states, attention_mask): + """ + Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66 + """ + if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: + dtype = hidden_states.dtype + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) + + return hidden_states + + +# Adapted from transformers.models.mamba2.modeling_mamba2.Mamba2Mixer +class BambaMixer(nn.Module): + """ + Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`. + A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective) + ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4, + and is why Mamba is called **selective** state spaces) + + The are a few differences between this and Mamba2Mixer: + - The variable use_precomputed_states is slightly different due to the HybridCache structure + - There's a few non-obvious bugs fixed with batching in the slow path that exist in main + - Some extra variables that our layer doesn't need have been removed + - We ported most of the refactors in https://github.com/huggingface/transformers/pull/35154, which is (as of Dec 18, 2024) unmerged + """ + + def __init__(self, config: BambaConfig, layer_idx: int): + super().__init__() + self.num_heads = config.mamba_n_heads + self.hidden_size = config.hidden_size + self.ssm_state_size = config.mamba_d_state + self.conv_kernel_size = config.mamba_d_conv + self.intermediate_size = int(config.mamba_expand * self.hidden_size) + self.layer_idx = layer_idx + self.use_conv_bias = config.mamba_conv_bias + self.activation = config.hidden_act + self.act = ACT2FN[config.hidden_act] + self.use_bias = config.mamba_proj_bias + + self.layer_norm_epsilon = config.rms_norm_eps + + self.n_groups = config.mamba_n_groups + self.head_dim = config.mamba_d_head + self.chunk_size = config.mamba_chunk_size + + # FIXME: + self.time_step_limit = (0.0, float("inf")) + self.time_step_min = 0.001 + self.time_step_max = 0.1 + + self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.ssm_state_size + self.conv1d = nn.Conv1d( + in_channels=self.conv_dim, + out_channels=self.conv_dim, + bias=config.mamba_conv_bias, + kernel_size=self.conv_kernel_size, + groups=self.conv_dim, + padding=self.conv_kernel_size - 1, + ) + + # projection of the input hidden states + projection_size = self.intermediate_size + self.conv_dim + self.num_heads + self.in_proj = nn.Linear( + self.hidden_size, + projection_size, + bias=self.use_bias, + ) + # selective projection used to make dt, B and C input dependant + + # time step projection (discretization) + # instantiate once and copy inv_dt in init_weights 
of PretrainedModel + self.dt_bias = nn.Parameter(torch.ones(self.num_heads)) + + # S4D real initialization. These are not discretized! + # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded + A = torch.arange(1, self.num_heads + 1) + self.A_log = nn.Parameter(torch.log(A)) + self.A_log._no_weight_decay = True + self.norm = BambaRMSNormGated(self.intermediate_size, eps=self.layer_norm_epsilon) + self.D = nn.Parameter(torch.ones(self.num_heads)) + self.D._no_weight_decay = True + + self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=self.use_bias) + + if not is_fast_path_available: + logger.warning_once( + "The fast path is not available because on of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`" + " is None. Falling back to the naive implementation. To install follow https://github.com/state-spaces/mamba/#installation and" + " https://github.com/Dao-AILab/causal-conv1d" + ) + else: + logger.warning_once("The fast path for Bamba will be used when running the model on a GPU") + + def cuda_kernels_forward( + self, + hidden_states: torch.Tensor, + cache_params: Optional[HybridMambaAttentionDynamicCache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ): + # 1. Gated MLP's linear projection + hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask) + projected_states = self.in_proj(hidden_states) + + # Set up dimensions for reshapes later + batch_size, seq_len, _ = hidden_states.shape + groups_time_state_size = self.n_groups * self.ssm_state_size + + use_precomputed_states = ( + cache_params is not None + and cache_params.has_previous_state + and seq_len == 1 + and cache_params.conv_states[self.layer_idx].shape[0] + == cache_params.ssm_states[self.layer_idx].shape[0] + == batch_size + and cache_position is not None + and cache_position[0] > 0 + ) + + # getting projected states from cache if it exists + if use_precomputed_states: + gate, hidden_states_B_C, dt = projected_states.squeeze(1).split( + [self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 + ) + + # 2. Convolution sequence transformation + hidden_states_B_C = causal_conv1d_update( + hidden_states_B_C, + cache_params.conv_states[self.layer_idx], + self.conv1d.weight.squeeze(1), + self.conv1d.bias, + self.activation, + ) + + hidden_states, B, C = torch.split( + hidden_states_B_C, + [self.intermediate_size, groups_time_state_size, groups_time_state_size], + dim=-1, + ) + + # 3. SSM transformation + A = -torch.exp(self.A_log.float()) # (nheads,) + A = A[:, None, ...][:, :, None].expand(-1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32) + dt = dt[:, :, None].expand(-1, -1, self.head_dim) + dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim) + D = self.D[:, None, ...].expand(-1, self.head_dim) + B = B.view(batch_size, self.n_groups, B.shape[1] // self.n_groups) + C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups) + hidden_states_reshaped = hidden_states.view(batch_size, self.num_heads, self.head_dim) + hidden_states = selective_state_update( + cache_params.ssm_states[self.layer_idx], + hidden_states_reshaped, + dt, + A, + B, + C, + D, + z=None, + dt_bias=dt_bias, + dt_softplus=True, + ) + hidden_states = hidden_states.view(batch_size, self.num_heads * self.head_dim) + hidden_states = self.norm(hidden_states, gate) + + # 4. Final linear projection + out = self.out_proj(hidden_states)[:, None, ...] 
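On the fast path above, `causal_conv1d_update` advances a rolling per-channel convolution buffer by a single token; the slow path further below does the same thing in plain PyTorch. A minimal sketch of that one-step update, with hypothetical sizes and SiLU standing in for the configured activation:

```python
import torch

# Hypothetical sizes for illustration only.
batch_size, channels, kernel_size = 2, 8, 4

# Rolling buffer of the last `kernel_size` inputs per channel, plus one new token.
conv_state = torch.randn(batch_size, channels, kernel_size)
new_column = torch.randn(batch_size, channels)
weight = torch.randn(channels, kernel_size)   # depthwise kernel: one filter per channel
bias = torch.randn(channels)

# Shift the window left by one position and append the newest token on the right.
conv_state = conv_state.roll(shifts=-1, dims=-1)
conv_state[:, :, -1] = new_column

# Depthwise causal convolution at the current position: a per-channel dot product
# over the window, plus bias and an activation (SiLU here as an example).
out = torch.nn.functional.silu((conv_state * weight).sum(dim=-1) + bias)
print(out.shape)  # torch.Size([2, 8])
```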
+ # Fused calculations or step by step if no initialized cache is found + else: + A = -torch.exp(self.A_log.float()) # (num_heads) or (intermediate_size, state_size) + dt_limit_kwargs = {} if self.time_step_limit == (0.0, float("inf")) else {"dt_limit": self.time_step_limit} + + # 2-4. Fused kernel for conv1d, SSM, and the final projection + if self.training and cache_params is None: + out = mamba_split_conv1d_scan_combined( + projected_states, + self.conv1d.weight.squeeze(1), + self.conv1d.bias, + self.dt_bias, + A, + D=self.D, + chunk_size=self.chunk_size, + seq_idx=None, # was seq_idx + activation=self.activation, + rmsnorm_weight=self.norm.weight, + rmsnorm_eps=self.norm.variance_epsilon, + outproj_weight=self.out_proj.weight, + outproj_bias=self.out_proj.bias, + headdim=self.head_dim, + ngroups=self.n_groups, + norm_before_gate=False, + return_final_states=False, + **dt_limit_kwargs, + ) + + else: + gate, hidden_states_B_C, dt = projected_states.split( + [self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 + ) + + # 2. Convolution sequence transformation + # Init cache + if cache_params is not None: + # storing the states + # If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv + # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise. + hidden_states_B_C_transposed = hidden_states_B_C.transpose(1, 2) + conv_states = nn.functional.pad( + hidden_states_B_C_transposed, + (self.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0), + ) + cache_params.conv_states[self.layer_idx].copy_(conv_states) + + if self.activation not in ["silu", "swish"]: + hidden_states_B_C = self.act( + self.conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2) + ) + else: + hidden_states_B_C = causal_conv1d_fn( + x=hidden_states_B_C.transpose(1, 2), + weight=self.conv1d.weight.squeeze(1), + bias=self.conv1d.bias, + activation=self.activation, + ).transpose(1, 2) + + hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask) + hidden_states, B, C = torch.split( + hidden_states_B_C, + [self.intermediate_size, groups_time_state_size, groups_time_state_size], + dim=-1, + ) + + # 3. SSM transformation + scan_output, ssm_state = mamba_chunk_scan_combined( + hidden_states.view(batch_size, seq_len, -1, self.head_dim), + dt, + A, + B.view(batch_size, seq_len, self.n_groups, -1), + C.view(batch_size, seq_len, self.n_groups, -1), + chunk_size=self.chunk_size, + D=self.D, + z=None, + seq_idx=None, + return_final_states=True, + dt_bias=self.dt_bias, + dt_softplus=True, + **dt_limit_kwargs, + ) + + # Init cache + if ssm_state is not None and cache_params is not None: + cache_params.ssm_states[self.layer_idx].copy_(ssm_state) + + scan_output = scan_output.view(batch_size, seq_len, -1) + # Multiply "gate" branch and apply extra normalization layer + scan_output = self.norm(scan_output, gate) + + # 4. Final linear projection + out = self.out_proj(scan_output) + return out + + # fmt: off + def torch_forward( + self, + input_states, + cache_params: Optional[HybridMambaAttentionDynamicCache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ): + batch_size, seq_len, _ = input_states.shape + dtype = input_states.dtype + + # 1. 
Gated MLP's linear projection + input_states = apply_mask_to_padding_states(input_states, attention_mask) + projected_states = self.in_proj(input_states) + gate, hidden_states_B_C, dt = projected_states.split( + [self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 + ) + + use_precomputed_states = ( + cache_params is not None + and cache_params.has_previous_state + and seq_len == 1 + and cache_params.conv_states[self.layer_idx].shape[0] + == cache_params.ssm_states[self.layer_idx].shape[0] + == batch_size + and cache_position is not None + and cache_position[0] > 0 + ) + + # 2. Convolution sequence transformation + if use_precomputed_states: + cache_params.conv_states[self.layer_idx] = cache_params.conv_states[self.layer_idx].roll(shifts=-1, dims=-1) + cache_params.conv_states[self.layer_idx][:, :, -1] = hidden_states_B_C[:, 0, :].to(cache_params.conv_states[self.layer_idx].device) + + # We need to guarantee that anything regarding the cache is on the same device + conv_states = cache_params.conv_states[self.layer_idx].to(device=self.conv1d.weight.device) + + hidden_states_B_C = torch.sum( + conv_states * self.conv1d.weight.squeeze(1), dim=-1 + ) + if self.use_conv_bias: + hidden_states_B_C = hidden_states_B_C + self.conv1d.bias + hidden_states_B_C = self.act(hidden_states_B_C) + else: + # Init cache + if cache_params is not None: + hidden_states_B_C_transposed = hidden_states_B_C.transpose(1, 2) + conv_states = nn.functional.pad( + hidden_states_B_C_transposed, (self.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0) + ) + cache_params.conv_states[self.layer_idx].copy_(conv_states) + + hidden_states_B_C = self.act(self.conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2)) + + hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask) + hidden_states, B, C = torch.split( + hidden_states_B_C, + [self.intermediate_size, self.n_groups * self.ssm_state_size, self.n_groups * self.ssm_state_size], + dim=-1 + ) + + # 3. SSM transformation + A = -torch.exp(self.A_log.float()) # [num_heads] + if use_precomputed_states: + # We need to guarantee that anything regarding the cache is on the same device + cache_device = cache_params.ssm_states[self.layer_idx].device + + # Note: there is no need to pad parameter matrices here, as there is just one new token + # for batched generation + dt = dt[:, 0, :][:, None, ...] 
+ dt = dt.transpose(1, 2).expand(batch_size, dt.shape[-1], self.head_dim) + # [num_heads] -> [num_heads, head_dim] + dt_bias = self.dt_bias[..., None].expand(self.dt_bias.shape[0], self.head_dim) + + dt = torch.nn.functional.softplus(dt + dt_bias.to(dt.dtype)) + dt = torch.clamp(dt, self.time_step_limit[0], self.time_step_limit[1]) + A = A[..., None, None].expand(self.num_heads, self.head_dim, self.ssm_state_size).to(dtype=torch.float32) + # [bsz, num_heads, head_dim, state_size] + dA = (torch.exp(dt[..., None] * A)).to(device=cache_device) + + # Discretize B + # [bsz, n_groups * state_size] -> [bsz, n_groups, 1, state_size] -> + # -> [bsz, n_groups, group to head repetition factor, state_size] -> [bsz, num_heads, state_size] + B = B.reshape(batch_size, self.n_groups, -1)[..., None, :] + B = B.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, B.shape[-1]).contiguous() + B = B.reshape(batch_size, -1, B.shape[-1]) + # [bsz, num_heads, head_dim, state_size] + dB = dt[..., None] * B[..., None, :] + + # Discretize x into dB + # [bsz, intermediate_size] -> [bsz, num_heads, head_dim] + hidden_states = hidden_states.reshape(batch_size, -1, self.head_dim) + dBx = (dB * hidden_states[..., None]).to(device=cache_device) + + # State calculation + cache_params.ssm_states[self.layer_idx].copy_( + cache_params.ssm_states[self.layer_idx] * dA + dBx + ) + + # Subsequent output + # [bsz, n_groups * state_size] -> [bsz, num_heads, state_size] + C = C.reshape(batch_size, self.n_groups, -1)[..., None, :] + C = C.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, C.shape[-1]).contiguous() + C = C.reshape(batch_size, -1, C.shape[-1]) + # [bsz, num_heads, head_dim] + + ssm_states = cache_params.ssm_states[self.layer_idx].to(device=C.device, dtype=C.dtype) # Shape: [b, h, d, n] + # Reshape ssm_states to merge the first two dimensions + ssm_states_reshaped = ssm_states.view(batch_size * self.num_heads, self.head_dim, self.ssm_state_size) # Shape: [b*h, d, n] + C_reshaped = C.view(batch_size * self.num_heads, self.ssm_state_size, 1) # Shape: [b*h, n, 1] + y = torch.bmm(ssm_states_reshaped, C_reshaped) + y = y.view(batch_size, self.num_heads, self.head_dim) + + # D skip connection + # [num_heads] -> [num_heads, head_dim] + D = self.D[..., None].expand(self.D.shape[0], self.head_dim) + y = (y + hidden_states * D).to(y.dtype) + + # [bsz, num_heads, head_dim] -> [bsz, 1, intermediate_size] + y = y.reshape(batch_size, -1)[:, None, ...] 
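For orientation, the single-token branch above (and `selective_state_update` on the fast path) implements the discretized state-space recurrence `state = state * exp(dt * A) + dt * B * x`, read out as `y = C · state + D * x`. A toy, self-contained version with made-up sizes, omitting the group-to-head broadcasting done in the real mixer:

```python
import torch

# Toy sizes: one head, small head_dim and state_size.
batch_size, num_heads, head_dim, state_size = 2, 1, 4, 8

A_log = torch.zeros(num_heads)                  # learned parameter in the real mixer
A = -torch.exp(A_log)                           # A = -exp(A_log) is always negative
dt = torch.nn.functional.softplus(torch.randn(batch_size, num_heads))  # positive step size
B = torch.randn(batch_size, num_heads, state_size)
C = torch.randn(batch_size, num_heads, state_size)
D = torch.ones(num_heads)
x = torch.randn(batch_size, num_heads, head_dim)       # current token, per head
state = torch.zeros(batch_size, num_heads, head_dim, state_size)

# Discretize: dA in (0, 1) decays the old state, dBx injects the new input.
dA = torch.exp(dt * A)[..., None, None]                         # (batch, heads, 1, 1)
dBx = dt[..., None, None] * B[:, :, None, :] * x[..., None]     # (batch, heads, head_dim, state_size)
state = state * dA + dBx

# Read out: contract the state with C and add the D skip connection.
y = (state * C[:, :, None, :]).sum(-1) + D[..., None] * x       # (batch, heads, head_dim)
print(y.shape)  # torch.Size([2, 1, 4])
```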
+ else: + # begin ssd naive implementation without einsums + dt = nn.functional.softplus(dt + self.dt_bias) + dt = torch.clamp(dt, self.time_step_limit[0], self.time_step_limit[1]) + hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).float() + B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).float() + C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).float() + B = B.repeat(1, 1, self.num_heads // self.n_groups, 1) + C = C.repeat(1, 1, self.num_heads // self.n_groups, 1) + pad_size = (self.chunk_size - seq_len % self.chunk_size) % self.chunk_size + + D_residual = self.D[..., None] * pad_tensor_by_size(hidden_states, pad_size) + + # Discretize x and A + hidden_states = hidden_states * dt[..., None] + A = A.to(hidden_states.dtype) * dt + + # Rearrange into blocks/chunks + hidden_states, A, B, C = [reshape_into_chunks(t, pad_size, self.chunk_size) for t in (hidden_states, A, B, C)] + + # [bsz, -1, chunk_size, num_heads] -> [bsz, num_heads, -1, chunk_size] + A = A.permute(0, 3, 1, 2) + A_cumsum = torch.cumsum(A, dim=-1) + + # 1. Compute the output for each intra-chunk (diagonal blocks) + # This is the analog of a causal mask + L = torch.exp(segment_sum(A)) + + # Contraction of C and B to get G (attention-weights like) + G_intermediate = C[:, :, :, None, :, :] * B[:, :, None, :, :, :] # shape: (b, c, l, s, h, n) + G = G_intermediate.sum(dim=-1) # shape: (b, c, l, s, h) + + # Compute M, equivalent to applying attention mask to weights + M_intermediate = G[..., None] * L.permute(0, 2, 3, 4, 1)[..., None] + M = M_intermediate.sum(dim=-1) + + # Compute Y_diag (apply to values) + Y_diag = (M[..., None] * hidden_states[:, :, None]).sum(dim=3) + + # 2. Compute the state for each intra-chunk + # (right term of low-rank factorization of off-diagonal blocks; B terms) + decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum)) + B_decay = B * decay_states.permute(0, -2, -1, 1)[..., None] + states = (B_decay[..., None, :] * hidden_states[..., None]).sum(dim=2) + + # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries + # (middle term of factorization of off-diag blocks; A terms) + if use_precomputed_states: + previous_states = cache_params.ssm_states[self.layer_idx][:, None, ...].to(device=states.device) + else: + previous_states = torch.zeros_like(states[:, :1]) + states = torch.cat([previous_states, states], dim=1) + decay_chunk = torch.exp(segment_sum(nn.functional.pad(A_cumsum[:, :, :, -1], (1, 0)))) + decay_chunk = decay_chunk.transpose(1, 3) + new_states = (decay_chunk[..., None, None] * states[:, :, None, ...]).sum(dim=1) + states, ssm_state = new_states[:, :-1], new_states[:, -1] + + # 4. 
Compute state -> output conversion per chunk + # (left term of low-rank factorization of off-diagonal blocks; C terms) + state_decay_out = torch.exp(A_cumsum) + C_times_states = (C[..., None, :] * states[:, :, None, ...]) + state_decay_out_permuted = state_decay_out.permute(0, 2, 3, 1) + Y_off = (C_times_states.sum(-1) * state_decay_out_permuted[..., None]) + + # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks) + y = Y_diag + Y_off + # [bsz, -1, self.chunk_size, num_heads, head_dim] -> [bsz, (padded) seq_len, num_heads, head_dim] + y = y.reshape(batch_size, -1, self.num_heads, self.head_dim) + + y = y + D_residual + # Cutting off padded chunks + if pad_size > 0: + y = y[:, :seq_len, :, :] + y = y.reshape(batch_size, seq_len, -1) + + # Init cache + if ssm_state is not None and cache_params is not None: + cache_params.ssm_states[self.layer_idx].copy_(ssm_state) + + scan_output = self.norm(y, gate) + + # end ssd naive + + # 4. Final linear projection + contextualized_states = self.out_proj(scan_output.to(dtype)) # [batch, seq_len, hidden_size] + return contextualized_states + # fmt: on + + def forward( + self, + hidden_states, + cache_params: Optional[HybridMambaAttentionDynamicCache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ): + if is_fast_path_available and "cuda" in self.in_proj.weight.device.type: + return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask) + dtype = hidden_states.dtype + if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: + # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) + + return self.torch_forward(hidden_states, cache_params, cache_position, attention_mask) + + +class BambaMLP(LlamaMLP): + pass + + +class BambaRMSNorm(LlamaRMSNorm): + pass + + +class BambaDecoderLayer(JambaAttentionDecoderLayer): + def __init__(self, config: BambaConfig, layer_idx: int, layer_type: str = "mamba"): + super().__init__() + + del self.self_attn + + num_experts = 1 + ffn_layer_class = BambaMLP if num_experts == 1 else None + self.feed_forward = ffn_layer_class(config) + + self.layer_type = layer_type + if layer_type == "mamba": + self.mamba = BambaMixer(config=config, layer_idx=layer_idx) + elif layer_type == "attention": + self.self_attn = BambaAttention(config, layer_idx) + else: + raise ValueError("Invalid layer_type") + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. 
+ past_key_value (`HybridMambaAttentionDynamicCache`, *optional*): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # this is a hybrid decoder layer + if self.layer_type == "mamba": + hidden_states = self.mamba( + hidden_states=hidden_states, + cache_params=past_key_value, + cache_position=cache_position, + attention_mask=attention_mask, + ) + self_attn_weights = None + elif self.layer_type == "attention": + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + + # residual connection after attention + hidden_states = residual + hidden_states + + # feed-forward + residual = hidden_states + hidden_states = self.pre_ff_layernorm(hidden_states) + hidden_states = self.feed_forward(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + return outputs + + +BAMBA_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`BambaConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. 
+""" + + +@add_start_docstrings( + "The bare BambaModel outputting raw hidden-states without any specific head on top.", + BAMBA_START_DOCSTRING, +) +class BambaPreTrainedModel(PreTrainedModel): + config_class = BambaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["BambaDecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True # Note: only supports HybridMambaAttentionDynamicCache + _is_stateful = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, (nn.Linear, nn.Conv1d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +BAMBA_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`HybridMambaAttentionDynamicCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + A HybridMambaAttentionDynamicCache object containing pre-computed hidden-states (keys and values in the + self-attention blocks and convolution and ssm states in the mamba blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + Key and value cache tensors have shape `(batch_size, num_heads, seq_len, head_dim)`. + Convolution and ssm states tensors have shape `(batch_size, d_inner, d_conv)` and + `(batch_size, d_inner, d_state)` respectively. + See the `HybridMambaAttentionDynamicCache` class for more details. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `input_ids` of shape `(batch_size, sequence_length)`. 
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + output_router_logits (`bool`, *optional*): + Whether or not to return the logits of all the routers. They are useful for computing the router loss, and + should not be returned during inference. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + "The bare Bamba Model outputting raw hidden-states without any specific head on top.", + BAMBA_START_DOCSTRING, +) +# Adapted from transformers.models.jamba.modeling_jamba.JambaModel +class BambaModel(BambaPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`BambaDecoderLayer`] + + Args: + config: BambaConfig + """ + + def __init__(self, config: BambaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + decoder_layers = [] + for i in range(config.num_hidden_layers): + decoder_layers.append(BambaDecoderLayer(config, layer_idx=i, layer_type=config.layers_block_type[i])) + self.layers = nn.ModuleList(decoder_layers) + + self._attn_implementation = config._attn_implementation + self.final_layernorm = BambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = BambaRotaryEmbedding(config=config) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(BAMBA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[HybridMambaAttentionDynamicCache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + hidden_states = inputs_embeds + + if use_cache and past_key_values is None: + logger.warning_once( + "Bamba requires an initialized `HybridMambaAttentionDynamicCache` to return a cache. None was " + "provided, so no cache will be returned." 
+ ) + + if cache_position is None: + cache_position = torch.arange(hidden_states.shape[1], device=hidden_states.device) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + mamba_mask = self._update_mamba_mask(attention_mask, cache_position) + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers: + # Depending on the layer type we opt for 2D base attention mask (Mamba) or 4D causal mask (Attention) + layer_mask = mamba_mask if decoder_layer.layer_type == "mamba" else causal_mask + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + layer_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=layer_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + if layer_outputs[1] is not None: + # append attentions only of attention layers. Mamba layers return `None` as the attention weights + all_self_attns += (layer_outputs[1],) + + hidden_states = self.final_layernorm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if past_key_values and not past_key_values.has_previous_state: + past_key_values.has_previous_state = True + + next_cache = None if not use_cache else past_key_values + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: HybridMambaAttentionDynamicCache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. 
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + sequence_length = input_tensor.shape[1] + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
+ causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_attention_mask = (attention_mask[:, None, None, :] == attention_mask[:, None, :, None])[ + :, :, -sequence_length:, : + ].to(dtype) + padding_mask = causal_mask[:, :, :, :mask_length] + padding_attention_mask + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + def _update_mamba_mask(self, attention_mask, cache_position): + """ + No need for zeroing states when + 1. Cached forward + 2. Attending to all inputs + """ + mamba_mask = attention_mask + if cache_position[0] > 0 or (attention_mask is not None and torch.all(attention_mask == 1)): + mamba_mask = None + return mamba_mask + + +class BambaForCausalLM(LlamaForCausalLM): + @add_start_docstrings_to_model_forward(BAMBA_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[HybridMambaAttentionDynamicCache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, + **kwargs, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + num_logits_to_keep (`int` or `None`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `None`, calculate logits for all + `input_ids`. Only last token logits are needed for generation, and calculating them only for that token + can save memory, which becomes pretty significant for long sequences. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, BambaForCausalLM + + >>> model = BambaForCausalLM.from_pretrained("...") + >>> tokenizer = AutoTokenizer.from_pretrained("...") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
+ ```""" + return super().forward( + input_ids, + attention_mask, + position_ids, + past_key_values, + inputs_embeds, + labels, + use_cache, + output_attentions, + output_hidden_states, + return_dict, + cache_position, + num_logits_to_keep, + **kwargs, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + **kwargs, + ): + # Overwitten -- has a unique cache type, `HybridMambaAttentionDynamicCache` + + empty_past_kv = past_key_values is None + + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here + if not empty_past_kv: + if inputs_embeds is not None: # Exception 1 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + else: + past_key_values = HybridMambaAttentionDynamicCache( + self.config, input_ids.shape[0], self.dtype, device=self.device + ) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if not empty_past_kv: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and empty_past_kv: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + "num_logits_to_keep": self.config.num_logits_to_keep, + "cache_position": cache_position, + } + ) + return model_inputs + + +__all__ = ["BambaModel", "BambaForCausalLM", "BambaPreTrainedModel"] diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 823c51a290713d..c9a49d737d092e 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -1167,6 +1167,27 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class BambaForCausalLM(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class BambaModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class BambaPreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class BarkCausalModel(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index bf56578a164c94..e85f2663624740 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -2313,6 +2313,7 @@ def _check_outputs(self, output, config, use_cache=False, num_return_sequences=1 # 2. We ignore models that have unique cache structures (e.g. 
mamba) or are in need of refatoring to match the # standard cache format (e.g.gptbigcode ) models_without_standard_cache = ( + "bamba", "ctrl", "fsmt", "gptbigcode", diff --git a/tests/models/bamba/__init__.py b/tests/models/bamba/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/tests/models/bamba/test_modeling_bamba.py b/tests/models/bamba/test_modeling_bamba.py new file mode 100644 index 00000000000000..45819e66b73c08 --- /dev/null +++ b/tests/models/bamba/test_modeling_bamba.py @@ -0,0 +1,603 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Testing suite for the PyTorch Bamba model.""" + +import inspect +import unittest + +import pytest +from parameterized import parameterized + +from transformers import AutoTokenizer, BambaConfig, is_torch_available +from transformers.testing_utils import ( + require_torch, + slow, + torch_device, +) + +from ...generation.test_utils import GenerationTesterMixin +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, _config_zero_init, ids_tensor +from ...test_pipeline_mixin import PipelineTesterMixin + + +if is_torch_available(): + import torch + + from transformers import ( + BambaForCausalLM, + BambaModel, + ) + from transformers.models.bamba.modeling_bamba import ( + HybridMambaAttentionDynamicCache, + ) + + +class BambaModelTester: + def __init__( + self, + parent, + batch_size=13, + seq_length=7, + is_training=True, + use_input_mask=True, + use_labels=True, + vocab_size=99, + hidden_size=32, + num_hidden_layers=4, + num_attention_heads=4, + num_key_value_heads=2, + intermediate_size=64, + hidden_act="silu", + attention_dropout=0.0, + attn_layer_indices=None, + attn_rotary_emb=8, + max_position_embeddings=512, + type_vocab_size=16, + initializer_range=0.02, + num_labels=3, + pad_token_id=0, + mamba_n_groups=1, + mamba_n_heads=16, + mamba_d_state=16, + mamba_d_conv=4, + mamba_expand=2, + mamba_chunk_size=16, + scope=None, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.use_input_mask = use_input_mask + self.use_labels = use_labels + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.attention_dropout = attention_dropout + self.attn_layer_indices = attn_layer_indices + self.attn_rotary_emb = attn_rotary_emb + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + self.num_labels = num_labels + self.pad_token_id = pad_token_id + self.scope = scope + self.mamba_n_groups = mamba_n_groups + self.mamba_n_heads = mamba_n_heads + self.mamba_d_state = mamba_d_state + self.mamba_d_conv = 
mamba_d_conv + self.mamba_expand = mamba_expand + self.mamba_chunk_size = mamba_chunk_size + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + + input_mask = None + if self.use_input_mask: + input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device)) + + token_labels = None + if self.use_labels: + token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) + + config = self.get_config() + + return config, input_ids, input_mask, token_labels + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + input_ids, + input_mask, + token_labels, + ) = config_and_inputs + inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask} + return config, inputs_dict + + def get_config(self): + # Fix for SDPA tests, force at least 4 layers + if self.num_hidden_layers < 4: + self.num_hidden_layers = 4 + if self.attn_layer_indices is None: + d = [x for x in range(2, self.num_hidden_layers) if self.num_hidden_layers % x == 0] + if len(d) == 0: + raise ValueError("num_hidden_layers is prime, cannot automatically set attn_layer_indices.") + d = d[-1] # get the largest divisor + self.attn_layer_indices = [x + 1 for x in range(0, self.num_hidden_layers, d)] + + return BambaConfig( + vocab_size=self.vocab_size, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + num_key_value_heads=self.num_key_value_heads, + intermediate_size=self.intermediate_size, + hidden_act=self.hidden_act, + attention_dropout=self.attention_dropout, + attn_layer_indices=self.attn_layer_indices, + attn_rotary_emb=self.attn_rotary_emb, + max_position_embeddings=self.max_position_embeddings, + initializer_range=self.initializer_range, + pad_token_id=self.pad_token_id, + mamba_n_groups=self.mamba_n_groups, + mamba_n_heads=self.mamba_n_heads, + mamba_d_state=self.mamba_d_state, + mamba_d_conv=self.mamba_d_conv, + mamba_expand=self.mamba_expand, + mamba_chunk_size=self.mamba_chunk_size, + ) + + def create_and_check_model( + self, + config, + input_ids, + input_mask, + token_labels, + ): + model = BambaModel(config=config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=input_mask) + result = model(input_ids) + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + + def create_and_check_for_causal_lm( + self, + config, + input_ids, + input_mask, + token_labels, + ): + model = BambaForCausalLM(config=config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=input_mask, labels=token_labels) + result = model(input_ids, attention_mask=input_mask) + result = model(input_ids, labels=token_labels) + result = model(input_ids) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) + + def create_and_check_decoder_model_past_large_inputs( + self, + config, + input_ids, + input_mask, + token_labels, + ): + # config.is_decoder = True + # config.add_cross_attention = True + model = BambaForCausalLM(config=config) + model.to(torch_device) + model.eval() + + # first forward pass + # Attention: Jamba needs the cache to be initialized to return a cache! 
+ past_key_values = HybridMambaAttentionDynamicCache( + config, input_ids.shape[0], model.dtype, device=model.device + ) + outputs = model( + input_ids, + attention_mask=input_mask, + past_key_values=past_key_values, + use_cache=True, + ) + past_key_values = outputs.past_key_values + + # create hypothetical multiple next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size) + next_mask = ids_tensor((self.batch_size, 3), vocab_size=2) + + # append to next input_ids and + next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) + next_attention_mask = torch.cat([input_mask, next_mask], dim=-1) + + output_from_no_past = model( + next_input_ids, + attention_mask=next_attention_mask, + output_hidden_states=True, + )["hidden_states"][0] + output_from_past = model( + next_tokens, + attention_mask=next_attention_mask, + past_key_values=past_key_values, + output_hidden_states=True, + cache_position=torch.arange( + input_ids.shape[1], input_ids.shape[1] + next_tokens.shape[1], device=model.device + ), + )["hidden_states"][0] + + # select random slice + random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() + output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach() + output_from_past_slice = output_from_past[:, :, random_slice_idx].detach() + + self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1]) + + # test that outputs are equal for slice + self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)) + + +@require_torch +class BambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): + all_model_classes = ( + ( + BambaModel, + BambaForCausalLM, + ) + if is_torch_available() + else () + ) + all_generative_model_classes = (BambaForCausalLM,) if is_torch_available() else () + pipeline_model_mapping = ( + { + "feature-extraction": BambaModel, + "text-generation": BambaForCausalLM, + } + if is_torch_available() + else {} + ) + test_headmasking = False + test_pruning = False + fx_compatible = False + + # Need to use `0.8` instead of `0.9` for `test_cpu_offload` + # This is because we are hitting edge cases with the causal_mask buffer + model_split_percents = [0.5, 0.7, 0.8] + + def setUp(self): + self.model_tester = BambaModelTester(self) + self.config_tester = ConfigTester(self, config_class=BambaConfig, hidden_size=64) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_for_casual_lm(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_causal_lm(*config_and_inputs) + + def test_decoder_model_past_with_large_inputs(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs) + + def test_initialization(self): + r""" + Overriding the test_initialization test as the A_log and D params of the Bamba mixer are initialized differently + """ + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + configs_no_init = _config_zero_init(config) + for model_class in self.all_model_classes: + model = model_class(config=configs_no_init) + for name, param in model.named_parameters(): + if param.requires_grad: + if "A_log" in name: + A = 
torch.arange(1, config.mamba_n_heads + 1, dtype=torch.float32)[None, :] + self.assertTrue(torch.allclose(param.data, torch.log(A), atol=1e-5, rtol=1e-5)) + elif "D" in name: + D = torch.ones(config.mamba_n_heads, dtype=torch.float32) + self.assertTrue(torch.allclose(param.data, D, atol=1e-5, rtol=1e-5)) + else: + self.assertIn( + ((param.data.mean() * 1e9).round() / 1e9).item(), + [0.0, 1.0], + msg=f"Parameter {name} of model {model_class} seems not properly initialized", + ) + + def test_mismatched_shapes_have_properly_initialized_weights(self): + r""" + Overriding the test_mismatched_shapes_have_properly_initialized_weights test because A_log and D params of the + Bamba mixer are initialized differently and we tested that in test_initialization + """ + self.skipTest(reason="Cumbersome and redundant for Bamba") + + def test_attention_outputs(self): + r""" + Overriding the test_attention_outputs test as the Bamba model outputs attention only for its attention layers + """ + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.return_dict = True + + seq_len = getattr(self.model_tester, "seq_length", None) + encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len) + encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length) + + expected_num_attentions = self.model_tester.num_hidden_layers - len(self.model_tester.attn_layer_indices) + + for model_class in self.all_model_classes: + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = False + config.return_dict = True + model = model_class(config) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.attentions + self.assertEqual(len(attentions), expected_num_attentions) + + # check that output_attentions also work using config + del inputs_dict["output_attentions"] + config.output_attentions = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.attentions + self.assertEqual(len(attentions), expected_num_attentions) + + self.assertListEqual( + list(attentions[0].shape[-3:]), + [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length], + ) + out_len = len(outputs) + + # Check attention is always last and order is fine + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + added_hidden_states = 1 + self.assertEqual(out_len + added_hidden_states, len(outputs)) + + self_attentions = outputs.attentions + + self.assertEqual(len(self_attentions), expected_num_attentions) + self.assertListEqual( + list(self_attentions[0].shape[-3:]), + [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length], + ) + + @unittest.skip(reason="Bamba has its own special cache type") + @parameterized.expand([(1, False), (1, True), (4, False)]) + def test_new_cache_format(self, num_beams, do_sample): + pass + + def test_batching_equivalence(self): + # need to disable the tril input mask + orig = self.model_tester.use_input_mask + self.model_tester.use_input_mask = False + super().test_batching_equivalence() + self.model_tester.use_input_mask = orig + + # essentially the same test in 
test_utils, just adjustment for rtol for this model + @pytest.mark.generate + def test_left_padding_compatibility(self): + # NOTE: left-padding results in small numerical differences. This is expected. + # See https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535 + + # First, filter out models that don't support left padding + # - The model must have generative capabilities + if len(self.all_generative_model_classes) == 0: + self.skipTest(reason="No generative architecture available for this model.") + + # - The model must support padding + if not self.has_attentions: + self.skipTest(reason="This model doesn't support padding.") + + # - The model must be a decoder-only architecture (encoder-based architectures use right-padding) + decoder_only_classes = [] + for model_class in self.all_generative_model_classes: + config, _ = self.prepare_config_and_inputs_for_generate() + if config.is_encoder_decoder: + continue + else: + decoder_only_classes.append(model_class) + if len(decoder_only_classes) == 0: + self.skipTest(reason="No decoder-only architecture available for this model.") + + # - Decoder-only architectures derived from encoder-decoder models could support it in theory, but we haven't + # added support for it yet. We skip these models for now. + has_encoder_attributes = any( + attr_name + for attr_name in config.to_dict().keys() + if attr_name.startswith("encoder") and attr_name != "encoder_no_repeat_ngram_size" + ) + if has_encoder_attributes: + self.skipTest( + reason="The decoder-only derived from encoder-decoder models are not expected to support left-padding." + ) + + # Then, test left-padding + def _prepare_model_kwargs(input_ids, attention_mask, signature): + model_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask} + if "position_ids" in signature: + position_ids = torch.cumsum(attention_mask, dim=-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + model_kwargs["position_ids"] = position_ids + if "cache_position" in signature: + cache_position = torch.arange(input_ids.shape[-1], device=torch_device) + model_kwargs["cache_position"] = cache_position + return model_kwargs + + for model_class in decoder_only_classes: + config, inputs_dict = self.prepare_config_and_inputs_for_generate() + input_ids = inputs_dict["input_ids"] + + # - for left padding we absolutely need to use an all ones + # attention mask, so we do not use the one in inputs_dict + attention_mask = torch.ones_like(input_ids) + + model = model_class(config).to(torch_device).eval() + signature = inspect.signature(model.forward).parameters.keys() + + # no cache as some models require special cache classes to be init outside forward + model.generation_config.use_cache = False + + # Without padding + model_kwargs = _prepare_model_kwargs(input_ids, attention_mask, signature) + next_logits_wo_padding = model(**model_kwargs).logits[:, -1, :] + + # With left-padding (length 32) + # can hardcode pad_token to be 0 as we'll do attn masking anyway + pad_token_id = ( + config.get_text_config().pad_token_id if config.get_text_config().pad_token_id is not None else 0 + ) + pad_size = (input_ids.shape[0], 32) + padding = torch.ones(pad_size, dtype=input_ids.dtype, device=torch_device) * pad_token_id + padded_input_ids = torch.cat((padding, input_ids), dim=1) + padded_attention_mask = torch.cat((torch.zeros_like(padding), attention_mask), dim=1) + model_kwargs = _prepare_model_kwargs(padded_input_ids, padded_attention_mask, signature) + next_logits_with_padding = 
model(**model_kwargs).logits[:, -1, :] + + # They should result in very similar logits + torch.testing.assert_close(next_logits_wo_padding, next_logits_with_padding, atol=1e-5, rtol=1e-1) + + +@slow +@require_torch +class BambaModelIntegrationTest(unittest.TestCase): + model = None + tokenizer = None + # This variable is used to determine which CUDA device are we using for our runners (A10 or T4) + # Depending on the hardware we get different logits / generations + cuda_compute_capability_major_version = None + + @classmethod + def setUpClass(cls): + model_id = "ibm-fms/Bamba-9B" + cls.model = BambaForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True) + cls.tokenizer = AutoTokenizer.from_pretrained(model_id) + + # feels a bit forced to have to do this for the generation test + cls.tokenizer.pad_token_id = cls.model.config.pad_token_id + cls.tokenizer.padding_side = "left" + + if is_torch_available() and torch.cuda.is_available(): + # 8 is for A100 / A10 and 7 for T4 + cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0] + + def test_simple_generate(self): + # Key 9 for MI300, Key 8 for A100/A10, and Key 7 for T4. + # + # Note: Key 9 is currently set for MI300, but may need potential future adjustments for H100s, + # considering differences in hardware processing and potential deviations in generated text. + EXPECTED_TEXTS = { + # 7: "", + 8: "<|begin_of_text|>Hey how are you doing on this lovely evening? I hope you are all having a good time.", + # 9: """, + } + + self.model.to(torch_device) + + input_ids = self.tokenizer("Hey how are you doing on this lovely evening?", return_tensors="pt")[ + "input_ids" + ].to(torch_device) + out = self.model.generate(input_ids, do_sample=False, max_new_tokens=10) + output_sentence = self.tokenizer.decode(out[0, :]) + self.assertEqual(output_sentence, EXPECTED_TEXTS[self.cuda_compute_capability_major_version]) + + # TODO: there are significant differences in the logits across major cuda versions, which shouldn't exist + if self.cuda_compute_capability_major_version == 8: + with torch.no_grad(): + logits = self.model(input_ids=input_ids, num_logits_to_keep=40).logits + + EXPECTED_LOGITS_NO_GRAD = torch.tensor( + [ + 149., 142., 146., 142., 143., 144., 142., 145., + 142., 146., 144., 146., 147., 147., 148., 145., + 147., 145., 145., 145., 145., 144., 144., 144., + 144., 145., 147., 146., 144., 144., 148., 147., + 148., 147., 147., 147., 146., 146., 148., 148. + ], dtype=torch.bfloat16) # fmt: skip + + torch.testing.assert_close(logits[0, -1, :40].cpu(), EXPECTED_LOGITS_NO_GRAD, rtol=1e-3, atol=1) + + def test_simple_batched_generate_with_padding(self): + # Key 9 for MI300, Key 8 for A100/A10, and Key 7 for T4. + # + # Note: Key 9 is currently set for MI300, but may need potential future adjustments for H100s, + # considering differences in hardware processing and potential deviations in generated text. + EXPECTED_TEXTS = { + 7: [], + 8: [ + "<|begin_of_text|>Hey how are you doing on this lovely evening? I hope you are doing well. I am here", + "!!!<|begin_of_text|>I am late! I need to get to work! I have to get to the", + ], + 9: [], + } + + self.model.to(torch_device) + + inputs = self.tokenizer( + ["Hey how are you doing on this lovely evening?", "I am late! 
I need to"], + padding=True, + return_tensors="pt", + ).to(torch_device) + out = self.model.generate(**inputs, do_sample=False, max_new_tokens=10) + output_sentences = self.tokenizer.batch_decode(out) + self.assertEqual(output_sentences[0], EXPECTED_TEXTS[self.cuda_compute_capability_major_version][0]) + self.assertEqual(output_sentences[1], EXPECTED_TEXTS[self.cuda_compute_capability_major_version][1]) + + # TODO: there are significant differences in the logits across major cuda versions, which shouldn't exist + if self.cuda_compute_capability_major_version == 8: + with torch.no_grad(): + logits = self.model(input_ids=inputs["input_ids"]).logits + + EXPECTED_LOGITS_NO_GRAD_0 = torch.tensor( + [ + 149., 142., 146., 142., 143., 144., 142., 145., + 142., 146., 144., 146., 147., 147., 148., 145., + 147., 145., 145., 145., 145., 144., 144., 144., + 144., 145., 147., 146., 144., 144., 148., 147., + 148., 147., 147., 147., 146., 146., 148., 148. + ], dtype=torch.bfloat16) # fmt: skip + + EXPECTED_LOGITS_NO_GRAD_1 = torch.tensor( + [ + 182., 178., 177., 174., 176., 176., 178., 178., + 177., 179., 176., 183., 180., 182., 179., 174., + 178., 176., 176., 175., 175., 175., 174., 173., + 174., 182., 180., 176., 177., 177., 180., 176., + 178., 177., 177., 175., 176., 177., 175., 177. + ], dtype=torch.bfloat16) # fmt: skip + + torch.testing.assert_close(logits[0, -1, :40].cpu(), EXPECTED_LOGITS_NO_GRAD_0, rtol=1e-3, atol=1) + torch.testing.assert_close(logits[1, -1, :40].cpu(), EXPECTED_LOGITS_NO_GRAD_1, rtol=1e-3, atol=1) diff --git a/utils/check_config_attributes.py b/utils/check_config_attributes.py index 420d6e6a2475d1..116e26e7834f26 100644 --- a/utils/check_config_attributes.py +++ b/utils/check_config_attributes.py @@ -34,6 +34,9 @@ SPECIAL_CASES_TO_ALLOW = { # 'max_position_embeddings' is not used in modeling file, but needed for eval frameworks like Huggingface's lighteval (https://github.com/huggingface/lighteval/blob/af24080ea4f16eaf1683e353042a2dfc9099f038/src/lighteval/models/base_model.py#L264). # periods and offsers are not used in modeling file, but used in the configuration file to define `layers_block_type` and `layers_num_experts`. + "BambaConfig": [ + "attn_layer_indices", + ], "JambaConfig": [ "max_position_embeddings", "attn_layer_offset", From d19b11f59b0cf337b445827228df796c6c335994 Mon Sep 17 00:00:00 2001 From: Tony Wu <28306721+tonywu71@users.noreply.github.com> Date: Thu, 19 Dec 2024 09:08:28 +0100 Subject: [PATCH 05/23] Fix documentation for ColPali (#35321) * docs: fix typo quickstart snippet in ColPali's model card * docs: clean the ColPali's model card * docs: make the `ColPaliForRetrieval`'s docstring more concise * docs: add missing bash command used to convert weights for `vidore/colpali-v1.3-hf` --- docs/source/en/model_doc/colpali.md | 21 +++++++------------ .../colpali/convert_colpali_weights_to_hf.py | 7 +++++++ .../models/colpali/modeling_colpali.py | 18 ++++++---------- 3 files changed, 21 insertions(+), 25 deletions(-) diff --git a/docs/source/en/model_doc/colpali.md b/docs/source/en/model_doc/colpali.md index d47f0aa072262c..3f6b0cbc6613a9 100644 --- a/docs/source/en/model_doc/colpali.md +++ b/docs/source/en/model_doc/colpali.md @@ -18,29 +18,24 @@ rendered properly in your Markdown viewer. 
## Overview -The ColPali model was proposed in [ColPali: Efficient Document Retrieval with Vision Language Models](https://doi.org/10.48550/arXiv.2407.01449) by **Manuel Faysse***, **Hugues Sibille***, **Tony Wu***, Bilel Omrani, Gautier Viaud, Céline Hudelot, Pierre Colombo (* denotes equal contribution). +The *ColPali* model was proposed in [ColPali: Efficient Document Retrieval with Vision Language Models](https://doi.org/10.48550/arXiv.2407.01449) by **Manuel Faysse***, **Hugues Sibille***, **Tony Wu***, Bilel Omrani, Gautier Viaud, Céline Hudelot, Pierre Colombo (* denotes equal contribution). Work lead by ILLUIN Technology. -With our new model *ColPali*, we propose to leverage VLMs to construct efficient multi-vector embeddings in the visual space for document retrieval. By feeding the ViT output patches from PaliGemma-3B to a linear projection, we create a multi-vector representation of documents. We train the model to maximize the similarity between these document embeddings and the query embeddings, following the ColBERT method. +In our proposed *ColPali* approach, we leverage VLMs to construct efficient multi-vector embeddings directly from document images (“screenshots”) for document retrieval. We train the model to maximize the similarity between these document embeddings and the corresponding query embeddings, using the late interaction method introduced in ColBERT. -Using ColPali removes the need for potentially complex and brittle layout recognition and OCR pipelines with a single model that can take into account both the textual and visual content (layout, charts, ...) of a document. ColPali is also highly interpretable: similarity maps can be obtained between patches and query tokens. These maps highlight ColPali’s strong OCR capabilities and chart understanding. - -**Paper abstract:** - -> Documents are visually rich structures that convey information through text, but also figures, page layouts, tables, or even fonts. Since modern retrieval systems mainly rely on the textual information they extract from document pages to index documents -often through lengthy and brittle processes-, they struggle to exploit key visual cues efficiently. This limits their capabilities in many practical document retrieval applications such as Retrieval Augmented Generation (RAG). To benchmark current systems on visually rich document retrieval, we introduce the Visual Document Retrieval Benchmark *ViDoRe*, composed of various page-level retrieval tasks spanning multiple domains, languages, and practical settings. The inherent complexity and performance shortcomings of modern systems motivate a new concept; doing document retrieval by directly embedding the images of the document pages. We release *ColPali*, a Vision Language Model trained to produce high-quality multi-vector embeddings from images of document pages. Combined with a late interaction matching mechanism, *ColPali* largely outperforms modern document retrieval pipelines while being drastically simpler, faster and end-to-end trainable. -> -> We release models, data, code and benchmarks under open licenses at [https://huggingface.co/vidore](https://huggingface.co/vidore). +Using *ColPali* removes the need for potentially complex and brittle layout recognition and OCR pipelines with a single model that can take into account both the textual and visual content (layout, charts, etc.) of a document. ## Resources +- The *ColPali* arXiv paper can be found [here](https://doi.org/10.48550/arXiv.2407.01449). 
📄 - The official blog post detailing ColPali can be found [here](https://huggingface.co/blog/manu/colpali). 📝 - The original model implementation code for the ColPali model and for the `colpali-engine` package can be found [here](https://github.com/illuin-tech/colpali). 🌎 -- Cookbooks for learning to use the transformers-native version of ColPali, fine-tuning, and similarity maps generation can be found [here](https://github.com/tonywu71/colpali-cookbooks). 📚 +- Cookbooks for learning to use the transformers-native version of *ColPali*, fine-tuning, and similarity maps generation can be found [here](https://github.com/tonywu71/colpali-cookbooks). 📚 This model was contributed by [@tonywu71](https://huggingface.co/tonywu71) and [@yonigozlan](https://huggingface.co/yonigozlan). ## Usage -This example demonstrates how to use ColPali to embed both queries and images, calculate their similarity scores, and identify the most relevant matches. For a specific query, you can retrieve the top-k most similar images by selecting the ones with the highest similarity scores. +This example demonstrates how to use *ColPali* to embed both queries and images, calculate their similarity scores, and identify the most relevant matches. For a specific query, you can retrieve the top-k most similar images by selecting the ones with the highest similarity scores. ```python import torch @@ -74,8 +69,8 @@ batch_queries = processor(text=queries).to(model.device) # Forward pass with torch.no_grad(): - image_embeddings = model(**batch_images) - query_embeddings = model(**batch_queries) + image_embeddings = model(**batch_images).embeddings + query_embeddings = model(**batch_queries).embeddings # Score the queries against the images scores = processor.score_retrieval(query_embeddings, image_embeddings) diff --git a/src/transformers/models/colpali/convert_colpali_weights_to_hf.py b/src/transformers/models/colpali/convert_colpali_weights_to_hf.py index 595974e0da1c3f..1b30f3f97acda3 100644 --- a/src/transformers/models/colpali/convert_colpali_weights_to_hf.py +++ b/src/transformers/models/colpali/convert_colpali_weights_to_hf.py @@ -26,6 +26,13 @@ --original_vlm_name_or_path google/paligemma-3b-mix-448 \ --output_dir vidore/colpali-v1.2-hf-internal \ --push_to_hub + +python src/transformers/models/colpali/convert_colpali_weights_to_hf.py \ + --model_id vidore/colpali-v1.3-merged \ + --revision 5b955e3415a7c5468ab33119d98d6d45c3a5b2c3 \ + --original_vlm_name_or_path google/paligemma-3b-mix-448 \ + --output_dir vidore/colpali-v1.3-hf \ + --push_to_hub ``` """ diff --git a/src/transformers/models/colpali/modeling_colpali.py b/src/transformers/models/colpali/modeling_colpali.py index 8bfff814c83756..d84f29a3414f0f 100644 --- a/src/transformers/models/colpali/modeling_colpali.py +++ b/src/transformers/models/colpali/modeling_colpali.py @@ -159,19 +159,13 @@ class ColPaliForRetrievalOutput(ModelOutput): @add_start_docstrings( """ - ColPali leverages Vision Language Models (VLMs) to construct efficient multi-vector embeddings in the visual space for document retrieval. - By feeding the ViT output patches from PaliGemma-3B to a linear projection, we create a multi-vector representation of documents. The model - is trained to maximize the similarity between these document embeddings and the query embeddings, following the ColBERT method. + In our proposed ColPali approach, we leverage VLMs to construct efficient multi-vector embeddings directly + from document images (“screenshots”) for document retrieval. 
We train the model to maximize the similarity + between these document embeddings and the corresponding query embeddings, using the late interaction method + introduced in ColBERT. - Using ColPali removes the need for potentially complex and brittle layout recognition and OCR pipelines with a single model that can take into account - both the textual and visual content (layout, charts, ...) of a document. - - ColPali was introduced in the following paper: [*ColPali: Efficient Document Retrieval with Vision Language Models*](https://arxiv.org/abs/2407.01449). - - Resources: - - A blog post detailing ColPali, a vision retrieval model, can be found [here](https://huggingface.co/blog/manu/colpali). 📝 - - The code for using and training the original ColPali model and for the `colpali-engine` package can be found [here](https://github.com/illuin-tech/colpali). 🌎 - - Cookbooks for learning to use the Hf version of ColPali, fine-tuning, and similarity maps generation can be found [here](https://github.com/tonywu71/colpali-cookbooks). 📚 + Using ColPali removes the need for potentially complex and brittle layout recognition and OCR pipelines with a + single model that can take into account both the textual and visual content (layout, charts, etc.) of a document. """ ) class ColPaliForRetrieval(ColPaliPreTrainedModel): From 4592cc9e9851828cc9632438a0271bb082238360 Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Thu, 19 Dec 2024 09:45:27 +0100 Subject: [PATCH 06/23] Update comment CI bot (#35323) * update * update --------- Co-authored-by: ydshieh --- .github/workflows/self-comment-ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/self-comment-ci.yml b/.github/workflows/self-comment-ci.yml index d6ef0af9ff83b5..b344ecfd59527d 100644 --- a/.github/workflows/self-comment-ci.yml +++ b/.github/workflows/self-comment-ci.yml @@ -15,7 +15,7 @@ jobs: runs-on: ubuntu-22.04 name: Get PR number # For security: only allow team members to run - if: contains(fromJSON('["ydshieh", "ArthurZucker", "zucchini-nlp", "qubvel", "molbap", "gante", "LysandreJik", "Cyrilvallez"]'), github.actor) + if: ${{ github.event.issue.state == 'open' && contains(fromJSON('["ydshieh", "ArthurZucker", "zucchini-nlp", "qubvel", "molbap", "gante", "LysandreJik", "Cyrilvallez"]'), github.actor) && (startsWith(github.event.comment.body, 'run-slow') || startsWith(github.event.comment.body, 'run slow') || startsWith(github.event.comment.body, 'run_slow')) }} outputs: PR_NUMBER: ${{ steps.set_pr_number.outputs.PR_NUMBER }} steps: From 56ff1e92fd3cb808c82f3fa8e79e236dbbea34d5 Mon Sep 17 00:00:00 2001 From: Peter Date: Thu, 19 Dec 2024 00:53:48 -0800 Subject: [PATCH 07/23] PaliGemma: Make sure to add to suffix if is present in `text` (#35201) Move suffix processing code to out of if statement --- .../models/paligemma/processing_paligemma.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/paligemma/processing_paligemma.py b/src/transformers/models/paligemma/processing_paligemma.py index cb35aab66cba49..5783308f831541 100644 --- a/src/transformers/models/paligemma/processing_paligemma.py +++ b/src/transformers/models/paligemma/processing_paligemma.py @@ -287,11 +287,6 @@ def __call__( elif not (isinstance(images, list) and isinstance(images[0], list) and is_valid_image(images[0][0])): raise ValueError("images must be an image, list of images or list of list of images") - if suffix is not None and 
_is_str_or_image(suffix): - suffix = [suffix] - if suffix is not None: - suffix = [sfx + self.tokenizer.eos_token for sfx in suffix] - input_strings = [ build_string_from_input( prompt=prompt, @@ -314,6 +309,11 @@ def __call__( ) expanded_samples.append(expanded_sample) input_strings = [f"{sample}\n" for sample in expanded_samples] + + if suffix is not None and _is_str_or_image(suffix): + suffix = [suffix] + if suffix is not None: + suffix = [sfx + self.tokenizer.eos_token for sfx in suffix] pixel_values = self.image_processor(images, **output_kwargs["images_kwargs"])["pixel_values"] # max_length has to account for the image tokens From 667ed5635e6fd7e2df4fc23012746b1c0cbb7575 Mon Sep 17 00:00:00 2001 From: Benjamin Warner Date: Thu, 19 Dec 2024 08:03:35 -0500 Subject: [PATCH 08/23] Add ModernBERT to Transformers (#35158) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * initial cut of modernbert for transformers * small bug fixes * fixes * Update import * Use compiled mlp->mlp_norm to match research implementation * Propagate changes in modular to modeling * Replace duplicate attn_out_dropout in favor of attention_dropout cc @warner-benjamin let me know if the two should remain separate! * Update BOS to CLS and EOS to SEP Please confirm @warner-benjamin * Set default classifier bias to False, matching research repo * Update tie_word_embeddings description * Fix _init_weights for ForMaskedLM * Match base_model_prefix * Add compiled_head to match research repo outputs * Fix imports for ModernBertForMaskedLM * Just use "gelu" default outright for classifier * Fix config name typo: initalizer -> initializer * Remove some unused parameters in docstring. Still lots to edit there! * Compile the embeddings forward Not having this resulted in very slight differences - so small it wasn't even noticed for the base model, only for the large model. But the tiny difference for large propagated at the embedding layer through the rest of the model, leading to notable differences of ~0.0084 average per value, up to 0.2343 for the worst case. * Add drafts for ForSequenceClassification/ForTokenClassification * Add initial SDPA support (not exactly equivalent to FA2 yet!) During testing, FA2 and SDPA still differ by about 0.0098 per value in the token embeddings. It still predicts the correct mask fills, but I'd like to get it fully 1-1 if possible. * Only use attention dropout if training * Add initial eager attention support (also not equivalent to FA2 yet!) Frustratingly, I also can't get eager to be equivalent to FA2 (or sdpa), but it does get really close, i.e. avg ~0.010 difference per value. Especially if I use fp32 for both FA2&eager, avg ~0.0029 difference per value The fill-mask results are good with eager. * Add initial tests, output_attentions, output_hidden_states, prune_heads Tests are based on BERT, not all tests pass yet: 23 failed, 79 passed, 100 skipped * Remove kwargs from ModernBertForMaskedLM Disable sparse_prediction by default to match the normal HF, can be enabled via config * Remove/adjust/skip improper tests; warn if padding but no attn mask * Run formatting etc. 
* Run python utils/custom_init_isort.py * FlexAttention with unpadded sequences(matches FA2 within bf16 numerics) * Reformat init_weights based on review * self -> module in attention forwards * Remove if config.tie_word_embeddings * Reformat output projection on a different line * Remove pruning * Remove assert * Call contiguous() to simplify paths * Remove prune_qkv_linear_layer * Format code * Keep as kwargs, only use if needed * Remove unused codepaths & related config options * Remove 3d attn_mask test; fix token classification tuple output * Reorder: attention_mask above position_ids, fixes gradient checkpointing * Fix usage if no FA2 or torch v2.5+ * Make torch.compile/triton optional Should we rename 'compile'? It's a bit vague * Separate pooling options into separate functions (cls, mean) - cls as default * Simplify _pad_modernbert_output, remove unused labels path * Update tied weights to remove decoder.weight, simplify decoder loading * Adaptively set config.compile based on hf_device_map/device/resize, etc. * Update ModernBertConfig docstring * Satisfy some consistency checks, add unfinished docs * Only set compile to False if there's more than 1 device * Add docstrings for public ModernBert classes * Dont replace docstring returns - ends up being duplicate * Fix mistake in toctree * Reformat toctree * Patched FlexAttention, SDPA, Eager with Local Attention * Implement FA2 -> SDPA -> Eager attn_impl defaulting, crucial both to match the original performance, and to get the highest inference speed without requiring users to manually pick FA2 * Patch test edge case with Idefics3 not working with 'attn_implementation="sdpa"' * Repad all_hidden_states as well * rename config.compile to reference_compile * disable flex_attention since it crashes * Update modernbert.md * Using dtype min to mask in eager * Fully remove flex attention for now It's only compatible with the nightly torch 2.6, so we'll leave it be for now. It's also slower than eager/sdpa. Also, update compile -> reference_compile in one more case * Call contiguous to allow for .view() * Copyright 2020 -> 2024 Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update/simplify __init__ structure Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Remove "... 
if dropout_prob > 0 else identity" As dropout with 0.0 should be efficient like identity * re-use existing pad/unpad functions instead of creating new ones * remove flexattention method * Compute attention_mask and local_attention_mask once in modeling * Simplify sequence classification prediction heads, only CLS now Users can make custom heads if they feel like it Also removes the unnecessary pool parameter * Simplify module.training in eager attn * Also export ModernBertPreTrainedModel * Update the documentation with links to finetuning scripts * Explain local_attention_mask parameter in docstring * Simplify _autoset_attn_implementation, rely on super() * Keep "in" to initialize Prediction head Doublechecked with Benjamin that it's correct/what we used for pretraining * add back mean pooling * Use the pooling head in TokenClassification * update copyright * Reset config._attn_implementation_internal on failure * Allow optional attention_mask in ForMaskedLM head * fix failing run_slow tests * Add links to the paper * Remove unpad_no_grad, always pad/unpad without gradients * local_attention_mask -> sliding_window_mask * Revert "Use the pooling head in TokenClassification" This reverts commit 99c38badd1dbce01d7aef41095fbf2f5cce87279. There was no real motivation, no info on whether having this bigger head does anything useful. * Simplify pooling, 2 options via if-else --------- Co-authored-by: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com> Co-authored-by: Tom Aarsen Co-authored-by: Said Taghadouini Co-authored-by: Benjamin Clavié Co-authored-by: Antoine Chaffin Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- docs/source/en/_toctree.yml | 2 + docs/source/en/index.md | 1 + docs/source/en/model_doc/modernbert.md | 91 ++ docs/source/en/perf_infer_gpu_one.md | 2 + src/transformers/__init__.py | 18 + src/transformers/loss/loss_utils.py | 17 + src/transformers/models/__init__.py | 1 + .../models/auto/configuration_auto.py | 2 + src/transformers/models/auto/modeling_auto.py | 4 + .../models/auto/tokenization_auto.py | 1 + .../models/modernbert/__init__.py | 27 + .../modernbert/configuration_modernbert.py | 213 +++ .../models/modernbert/modeling_modernbert.py | 1322 +++++++++++++++ .../models/modernbert/modular_modernbert.py | 1452 +++++++++++++++++ src/transformers/utils/dummy_pt_objects.py | 35 + src/transformers/utils/import_utils.py | 6 +- tests/models/modernbert/__init__.py | 0 .../modernbert/test_modeling_modernbert.py | 367 +++++ tests/test_modeling_common.py | 9 +- 19 files changed, 3568 insertions(+), 2 deletions(-) create mode 100644 docs/source/en/model_doc/modernbert.md create mode 100644 src/transformers/models/modernbert/__init__.py create mode 100644 src/transformers/models/modernbert/configuration_modernbert.py create mode 100644 src/transformers/models/modernbert/modeling_modernbert.py create mode 100644 src/transformers/models/modernbert/modular_modernbert.py create mode 100644 tests/models/modernbert/__init__.py create mode 100644 tests/models/modernbert/test_modeling_modernbert.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index c30dfd3fbabc97..8138dd41d80c12 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -498,6 +498,8 @@ title: mLUKE - local: model_doc/mobilebert title: MobileBERT + - local: model_doc/modernbert + title: ModernBert - local: model_doc/mpnet title: MPNet - local: model_doc/mpt diff --git a/docs/source/en/index.md b/docs/source/en/index.md index 
0bd81e9d61be29..967049d89cbe12 100644 --- a/docs/source/en/index.md +++ b/docs/source/en/index.md @@ -232,6 +232,7 @@ Flax), PyTorch, and/or TensorFlow. | [MobileNetV2](model_doc/mobilenet_v2) | ✅ | ❌ | ❌ | | [MobileViT](model_doc/mobilevit) | ✅ | ✅ | ❌ | | [MobileViTV2](model_doc/mobilevitv2) | ✅ | ❌ | ❌ | +| [ModernBERT](model_doc/modernbert) | ✅ | ❌ | ❌ | | [Moshi](model_doc/moshi) | ✅ | ❌ | ❌ | | [MPNet](model_doc/mpnet) | ✅ | ✅ | ❌ | | [MPT](model_doc/mpt) | ✅ | ❌ | ❌ | diff --git a/docs/source/en/model_doc/modernbert.md b/docs/source/en/model_doc/modernbert.md new file mode 100644 index 00000000000000..ab09f38ff12154 --- /dev/null +++ b/docs/source/en/model_doc/modernbert.md @@ -0,0 +1,91 @@ + + +# ModernBert + +
+<div class="flex flex-wrap space-x-1">
+<a href="https://huggingface.co/models?filter=modernbert">
+<img alt="Models" src="https://img.shields.io/badge/All_model_pages-modernbert-blueviolet">
+</a>
+<a href="https://huggingface.co/papers/2412.13663">
+<img alt="Paper page" src="https://img.shields.io/badge/Paper%20page-2412.13663-green">
+</a>
+</div>
+ +## Overview + +The ModernBert model was proposed in [Smarter, Better, Faster, Longer: A Modern Bidirectional Encoder for Fast, Memory Efficient, and Long Context Finetuning and Inference](https://arxiv.org/abs/2412.13663) by Benjamin Warner, Antoine Chaffin, Benjamin Clavié, Orion Weller, Oskar Hallström, Said Taghadouini, Alexis Galalgher, Raja Bisas, Faisal Ladhak, Tom Aarsen, Nathan Cooper, Grifin Adams, Jeremy Howard and Iacopo Poli. + +It is a refresh of the traditional encoder architecture, as used in previous models such as [BERT](https://huggingface.co/docs/transformers/en/model_doc/bert) and [RoBERTa](https://huggingface.co/docs/transformers/en/model_doc/roberta). + +It builds on BERT and implements many modern architectural improvements which have been developed since its original release, such as: +- [Rotary Positional Embeddings](https://huggingface.co/blog/designing-positional-encoding) to support sequences of up to 8192 tokens. +- [Unpadding](https://arxiv.org/abs/2208.08124) to ensure no compute is wasted on padding tokens, speeding up processing time for batches with mixed-length sequences. +- [GeGLU](https://arxiv.org/abs/2002.05202) Replacing the original MLP layers with GeGLU layers, shown to improve performance. +- [Alternating Attention](https://arxiv.org/abs/2004.05150v2) where most attention layers employ a sliding window of 128 tokens, with Global Attention only used every 3 layers. +- [Flash Attention](https://github.com/Dao-AILab/flash-attention) to speed up processing. +- A model designed following recent [The Case for Co-Designing Model Architectures with Hardware](https://arxiv.org/abs/2401.14489), ensuring maximum efficiency across inference GPUs. +- Modern training data scales (2 trillion tokens) and mixtures (including code ande math data) + +The abstract from the paper is the following: + +*Encoder-only transformer models such as BERT offer a great performance-size tradeoff for retrieval and classification tasks with respect to larger decoder-only models. Despite being the workhorse of numerous production pipelines, there have been limited Pareto improvements to BERT since its release. In this paper, we introduce ModernBERT, bringing modern model optimizations to encoder-only models and representing a major Pareto improvement over older encoders. Trained on 2 trillion tokens with a native 8192 sequence length, ModernBERT models exhibit state-of-the-art results on a large pool of evaluations encompassing diverse classification tasks and both single and multi-vector retrieval on different domains (including code). In addition to strong downstream performance, ModernBERT is also the most speed and memory efficient encoder and is designed for inference on common GPUs.* + +The original code can be found [here](https://github.com/answerdotai/modernbert). + +## Resources + +A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with ModernBert. + + + +- A script on how to [finetune for text similarity or information retrieval with Sentence Transformers](https://github.com/AnswerDotAI/ModernBERT/blob/main/examples/train_st.py). 🌎 +- A script on how to [finetune for information retrieval with PyLate](https://github.com/AnswerDotAI/ModernBERT/blob/main/examples/train_pylate.py). 
🌎 + + + +- [Masked language modeling task guide](../tasks/masked_language_modeling) + + +## ModernBertConfig + +[[autodoc]] ModernBertConfig + + + + +## ModernBertModel + +[[autodoc]] ModernBertModel + - forward + +## ModernBertForMaskedLM + +[[autodoc]] ModernBertForMaskedLM + - forward + +## ModernBertForSequenceClassification + +[[autodoc]] ModernBertForSequenceClassification + - forward + +## ModernBertForTokenClassification + +[[autodoc]] ModernBertForTokenClassification + - forward + + + diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md index 4f9cace5b8d30d..930f41b6fefba7 100644 --- a/docs/source/en/perf_infer_gpu_one.md +++ b/docs/source/en/perf_infer_gpu_one.md @@ -74,6 +74,7 @@ FlashAttention-2 is currently supported for the following architectures: * [MBart](https://huggingface.co/docs/transformers/model_doc/mbart#transformers.MBartModel) * [Mistral](https://huggingface.co/docs/transformers/model_doc/mistral#transformers.MistralModel) * [Mixtral](https://huggingface.co/docs/transformers/model_doc/mixtral#transformers.MixtralModel) +* [ModernBert](https://huggingface.co/docs/transformers/model_doc/modernbert#transformers.ModernBert) * [Moshi](https://huggingface.co/docs/transformers/model_doc/moshi#transformers.MoshiModel) * [Musicgen](https://huggingface.co/docs/transformers/model_doc/musicgen#transformers.MusicgenModel) * [MusicGen Melody](https://huggingface.co/docs/transformers/model_doc/musicgen_melody#transformers.MusicgenMelodyModel) @@ -265,6 +266,7 @@ For now, Transformers supports SDPA inference and training for the following arc * [Mistral](https://huggingface.co/docs/transformers/model_doc/mistral#transformers.MistralModel) * [Mllama](https://huggingface.co/docs/transformers/model_doc/mllama#transformers.MllamaForConditionalGeneration) * [Mixtral](https://huggingface.co/docs/transformers/model_doc/mixtral#transformers.MixtralModel) +* [ModernBert](https://huggingface.co/docs/transformers/model_doc/modernbert#transformers.ModernBert) * [Moshi](https://huggingface.co/docs/transformers/model_doc/moshi#transformers.MoshiModel) * [Musicgen](https://huggingface.co/docs/transformers/model_doc/musicgen#transformers.MusicgenModel) * [MusicGen Melody](https://huggingface.co/docs/transformers/model_doc/musicgen_melody#transformers.MusicgenMelodyModel) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 6a180a90bbbaa2..600d3d217fa8a9 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -606,6 +606,7 @@ "models.mobilenet_v2": ["MobileNetV2Config"], "models.mobilevit": ["MobileViTConfig"], "models.mobilevitv2": ["MobileViTV2Config"], + "models.modernbert": ["ModernBertConfig"], "models.moshi": [ "MoshiConfig", "MoshiDepthConfig", @@ -2869,6 +2870,15 @@ "MobileViTV2PreTrainedModel", ] ) + _import_structure["models.modernbert"].extend( + [ + "ModernBertForMaskedLM", + "ModernBertForSequenceClassification", + "ModernBertForTokenClassification", + "ModernBertModel", + "ModernBertPreTrainedModel", + ] + ) _import_structure["models.moshi"].extend( [ "MoshiForCausalLM", @@ -5565,6 +5575,7 @@ from .models.mobilevitv2 import ( MobileViTV2Config, ) + from .models.modernbert import ModernBertConfig from .models.moshi import ( MoshiConfig, MoshiDepthConfig, @@ -7556,6 +7567,13 @@ MobileViTV2Model, MobileViTV2PreTrainedModel, ) + from .models.modernbert import ( + ModernBertForMaskedLM, + ModernBertForSequenceClassification, + ModernBertForTokenClassification, + ModernBertModel, + 
ModernBertPreTrainedModel, + ) from .models.moshi import ( MoshiForCausalLM, MoshiForConditionalGeneration, diff --git a/src/transformers/loss/loss_utils.py b/src/transformers/loss/loss_utils.py index efa23d24e360b4..7f6aaaa44264ca 100644 --- a/src/transformers/loss/loss_utils.py +++ b/src/transformers/loss/loss_utils.py @@ -47,6 +47,22 @@ def ForCausalLMLoss( return loss +def ForMaskedLMLoss( + logits, labels, vocab_size: int, num_items_in_batch: int = None, ignore_index: int = -100, **kwargs +): + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() + + # Flatten the tokens + logits = logits.view(-1, vocab_size) + labels = labels.view(-1) + # Enable model parallelism + + labels = labels.to(logits.device) + loss = fixed_cross_entropy(logits, labels, num_items_in_batch, ignore_index, **kwargs) + return loss + + def ForSequenceClassificationLoss(labels, pooled_logits, config, **kwargs): num_labels = config.num_labels if config.problem_type is None: @@ -101,6 +117,7 @@ def ForTokenClassification(logits, labels, config, **kwargs): LOSS_MAPPING = { "ForCausalLM": ForCausalLMLoss, + "ForMaskedLM": ForMaskedLMLoss, "ForQuestionAnswering": ForQuestionAnsweringLoss, "ForSequenceClassification": ForSequenceClassificationLoss, "ForTokenClassification": ForTokenClassification, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 5b3c648428359d..7fcaddde704cf7 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -167,6 +167,7 @@ mobilenet_v2, mobilevit, mobilevitv2, + modernbert, moshi, mpnet, mpt, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 8aba0e75b2690b..69ce8efa10c76c 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -187,6 +187,7 @@ ("mobilenet_v2", "MobileNetV2Config"), ("mobilevit", "MobileViTConfig"), ("mobilevitv2", "MobileViTV2Config"), + ("modernbert", "ModernBertConfig"), ("moshi", "MoshiConfig"), ("mpnet", "MPNetConfig"), ("mpt", "MptConfig"), @@ -510,6 +511,7 @@ ("mobilenet_v2", "MobileNetV2"), ("mobilevit", "MobileViT"), ("mobilevitv2", "MobileViTV2"), + ("modernbert", "ModernBERT"), ("moshi", "Moshi"), ("mpnet", "MPNet"), ("mpt", "MPT"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 770e4ea0775f76..e8a2dece432476 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -176,6 +176,7 @@ ("mobilenet_v2", "MobileNetV2Model"), ("mobilevit", "MobileViTModel"), ("mobilevitv2", "MobileViTV2Model"), + ("modernbert", "ModernBertModel"), ("moshi", "MoshiModel"), ("mpnet", "MPNetModel"), ("mpt", "MptModel"), @@ -838,6 +839,7 @@ ("mega", "MegaForMaskedLM"), ("megatron-bert", "MegatronBertForMaskedLM"), ("mobilebert", "MobileBertForMaskedLM"), + ("modernbert", "ModernBertForMaskedLM"), ("mpnet", "MPNetForMaskedLM"), ("mra", "MraForMaskedLM"), ("mvp", "MvpForConditionalGeneration"), @@ -992,6 +994,7 @@ ("mistral", "MistralForSequenceClassification"), ("mixtral", "MixtralForSequenceClassification"), ("mobilebert", "MobileBertForSequenceClassification"), + ("modernbert", "ModernBertForSequenceClassification"), ("mpnet", "MPNetForSequenceClassification"), ("mpt", "MptForSequenceClassification"), ("mra", "MraForSequenceClassification"), @@ -1178,6 +1181,7 @@ ("mistral", 
"MistralForTokenClassification"), ("mixtral", "MixtralForTokenClassification"), ("mobilebert", "MobileBertForTokenClassification"), + ("modernbert", "ModernBertForTokenClassification"), ("mpnet", "MPNetForTokenClassification"), ("mpt", "MptForTokenClassification"), ("mra", "MraForTokenClassification"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 1cdebde8cd904f..350c230f142c15 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -313,6 +313,7 @@ ("mllama", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)), ("mluke", ("MLukeTokenizer" if is_sentencepiece_available() else None, None)), ("mobilebert", ("MobileBertTokenizer", "MobileBertTokenizerFast" if is_tokenizers_available() else None)), + ("modernbert", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)), ("moshi", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)), ("mpnet", ("MPNetTokenizer", "MPNetTokenizerFast" if is_tokenizers_available() else None)), ("mpt", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)), diff --git a/src/transformers/models/modernbert/__init__.py b/src/transformers/models/modernbert/__init__.py new file mode 100644 index 00000000000000..18317742981909 --- /dev/null +++ b/src/transformers/models/modernbert/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_modernbert import * + from .modeling_modernbert import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/modernbert/configuration_modernbert.py b/src/transformers/models/modernbert/configuration_modernbert.py new file mode 100644 index 00000000000000..13e9edf067efc4 --- /dev/null +++ b/src/transformers/models/modernbert/configuration_modernbert.py @@ -0,0 +1,213 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/modernbert/modular_modernbert.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_modernbert.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2024 Answer.AI, LightOn, and contributors, and the HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Literal + +from ...configuration_utils import PretrainedConfig + + +class ModernBertConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`ModernBertModel`]. It is used to instantiate an ModernBert + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the ModernBERT-base. + e.g. [answerdotai/ModernBERT-base](https://huggingface.co/answerdotai/ModernBERT-base) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 50368): + Vocabulary size of the ModernBert model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`ModernBertModel`] + hidden_size (`int`, *optional*, defaults to 768): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 1152): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 22): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer decoder. + hidden_activation (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the decoder. Will default to `"gelu"` + if not specified. + max_position_embeddings (`int`, *optional*, defaults to 8192): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + initializer_cutoff_factor (`float`, *optional*, defaults to 2.0): + The cutoff factor for the truncated_normal_initializer for initializing all weight matrices. + norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. + norm_bias (`bool`, *optional*, defaults to `False`): + Whether to use bias in the normalization layers. + pad_token_id (`int`, *optional*, defaults to 50283): + Padding token id. + eos_token_id (`int`, *optional*, defaults to 50282): + End of stream token id. + bos_token_id (`int`, *optional*, defaults to 50281): + Beginning of stream token id. + cls_token_id (`int`, *optional*, defaults to 50281): + Classification token id. + sep_token_id (`int`, *optional*, defaults to 50282): + Separation token id. + global_rope_theta (`float`, *optional*, defaults to 160000.0): + The base period of the global RoPE embeddings. + attention_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. 
+ global_attn_every_n_layers (`int`, *optional*, defaults to 3): + The number of layers between global attention layers. + local_attention (`int`, *optional*, defaults to 128): + The window size for local attention. + local_rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the local RoPE embeddings. + embedding_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the embeddings. + mlp_bias (`bool`, *optional*, defaults to `False`): + Whether to use bias in the MLP layers. + mlp_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the MLP layers. + decoder_bias (`bool`, *optional*, defaults to `True`): + Whether to use bias in the decoder layers. + classifier_pooling (`str`, *optional*, defaults to `"cls"`): + The pooling method for the classifier. Should be either `"cls"` or `"mean"`. In local attention layers, the + CLS token doesn't attend to all tokens on long sequences. + classifier_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the classifier. + classifier_bias (`bool`, *optional*, defaults to `False`): + Whether to use bias in the classifier. + classifier_activation (`str`, *optional*, defaults to `"gelu"`): + The activation function for the classifier. + deterministic_flash_attn (`bool`, *optional*, defaults to `False`): + Whether to use deterministic flash attention. If `False`, inference will be faster but not deterministic. + sparse_prediction (`bool`, *optional*, defaults to `False`): + Whether to use sparse prediction for the masked language model instead of returning the full dense logits. + sparse_pred_ignore_index (`int`, *optional*, defaults to -100): + The index to ignore for the sparse prediction. + reference_compile (`bool`, *optional*): + Whether to compile the layers of the model which were compiled during pretraining. If `None`, then parts of + the model will be compiled if 1) `triton` is installed, 2) the model is not on MPS, 3) the model is not + shared between devices, and 4) the model is not resized after initialization. If `True`, then the model may + be faster in some scenarios. 
+ + Examples: + + ```python + >>> from transformers import ModernBertModel, ModernBertConfig + + >>> # Initializing a ModernBert style configuration + >>> configuration = ModernBertConfig() + + >>> # Initializing a model from the modernbert-base style configuration + >>> model = ModernBertModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "modernbert" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=50368, + hidden_size=768, + intermediate_size=1152, + num_hidden_layers=22, + num_attention_heads=12, + hidden_activation="gelu", + max_position_embeddings=8192, + initializer_range=0.02, + initializer_cutoff_factor=2.0, + norm_eps=1e-5, + norm_bias=False, + pad_token_id=50283, + eos_token_id=50282, + bos_token_id=50281, + cls_token_id=50281, + sep_token_id=50282, + global_rope_theta=160000.0, + attention_bias=False, + attention_dropout=0.0, + global_attn_every_n_layers=3, + local_attention=128, + local_rope_theta=10000.0, + embedding_dropout=0.0, + mlp_bias=False, + mlp_dropout=0.0, + decoder_bias=True, + classifier_pooling: Literal["cls", "mean"] = "cls", + classifier_dropout=0.0, + classifier_bias=False, + classifier_activation="gelu", + deterministic_flash_attn=False, + sparse_prediction=False, + sparse_pred_ignore_index=-100, + reference_compile=None, + **kwargs, + ): + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + cls_token_id=cls_token_id, + sep_token_id=sep_token_id, + **kwargs, + ) + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.initializer_range = initializer_range + self.initializer_cutoff_factor = initializer_cutoff_factor + self.norm_eps = norm_eps + self.norm_bias = norm_bias + self.global_rope_theta = global_rope_theta + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.hidden_activation = hidden_activation + self.global_attn_every_n_layers = global_attn_every_n_layers + self.local_attention = local_attention + self.local_rope_theta = local_rope_theta + self.embedding_dropout = embedding_dropout + self.mlp_bias = mlp_bias + self.mlp_dropout = mlp_dropout + self.decoder_bias = decoder_bias + self.classifier_pooling = classifier_pooling + self.classifier_dropout = classifier_dropout + self.classifier_bias = classifier_bias + self.classifier_activation = classifier_activation + self.deterministic_flash_attn = deterministic_flash_attn + self.sparse_prediction = sparse_prediction + self.sparse_pred_ignore_index = sparse_pred_ignore_index + self.reference_compile = reference_compile + + if self.classifier_pooling not in ["cls", "mean"]: + raise ValueError( + f'Invalid value for `classifier_pooling`, should be either "cls" or "mean", but is {self.classifier_pooling}.' + ) + + +__all__ = ["ModernBertConfig"] diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py new file mode 100644 index 00000000000000..db8d98893f96fe --- /dev/null +++ b/src/transformers/models/modernbert/modeling_modernbert.py @@ -0,0 +1,1322 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/modernbert/modular_modernbert.py. 
+# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_modernbert.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2024 Answer.AI, LightOn, and contributors, and the HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Dict, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask +from ...modeling_outputs import BaseModelOutput, MaskedLMOutput, SequenceClassifierOutput, TokenClassifierOutput +from ...modeling_utils import PreTrainedModel +from ...utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + logging, +) +from ...utils.import_utils import is_triton_available +from .configuration_modernbert import ModernBertConfig + + +if is_flash_attn_2_available(): + from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func + from flash_attn.layers.rotary import RotaryEmbedding + from flash_attn.ops.triton.rotary import apply_rotary +else: + RotaryEmbedding = object + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "answerdotai/ModernBERT-base" +_CONFIG_FOR_DOC = "ModernBertConfig" + + +class ApplyRotaryEmbUnpad(torch.autograd.Function): + @staticmethod + def forward( + ctx, + qkv, + cos, + sin, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + ): + # (total_nnz, 3, nheads, headdim) + qkv = qkv.contiguous() + total_nnz, _three, _nheads, headdim = qkv.shape + # We need qkv to be contiguous so that when we reshape to combine (3, nheads) dimensions, + # we get the same tensor + # qk = rearrange(qkv[:, :2], "b_s t h d -> b_s (t h) d") + qk = qkv[:, :2].view(total_nnz, -1, headdim) + apply_rotary( + qk, + cos, + sin, + seqlen_offsets=0, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + interleaved=False, + inplace=True, + ) + + ctx.save_for_backward(cos, sin, cu_seqlens) + ctx.max_seqlen = max_seqlen + return qkv + + @staticmethod + def backward(ctx, do): + cos, sin, cu_seqlens = ctx.saved_tensors + do = do.contiguous() + total_nnz, _three, _nheads, headdim = do.shape + # We need dqkv to be contiguous so that when we reshape to combine (3, nheads) dimensions, + # we get the same tensor + dqk = do[:, :2].view(total_nnz, -1, headdim) + apply_rotary( + dqk, + cos, + sin, + seqlen_offsets=0, + cu_seqlens=cu_seqlens, + max_seqlen=ctx.max_seqlen, + interleaved=False, + inplace=True, + conjugate=True, + ) + + return do, None, None, None, None, None, None + + +def apply_rotary_unpadded( + qkv, + cos, + sin, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = 
None, +): + """ + Arguments: + qkv: (total_nnz, 3, nheads, headdim) - input tensor for packed QKV. + cos, sin: (seqlen_rotary, rotary_dim / 2) + interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead + of 1st half and 2nd half (GPT-NeoX style). + inplace: if True, apply rotary embedding in-place. + seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount. + Most commonly used in inference when we have KV cache. + cu_seqlens: (batch + 1,) or None + max_seqlen: int + Return: + out: (total_nnz, dim) + rotary_dim must be <= headdim + Apply rotary embedding to the first rotary_dim of x. + """ + return ApplyRotaryEmbUnpad.apply(qkv, cos, sin, cu_seqlens, max_seqlen) + + +class ModernBertUnpaddedRotaryEmbedding(RotaryEmbedding): + """ + The rotary position embeddings applied directly to unpadded sequences. + """ + + def __init__( + self, + dim: int, + base: float = 10000.0, + max_seqlen: Optional[int] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + """ + max_seqlen: if max_seqlen, device, and dtype are provided, we precompute the cos_sin_cache + up to max_seqlen. If the max_seqlen, device, or dtype during training/inference differ, + the cos_sin_cache wll be recomputed during the forward pass. + """ + super().__init__(dim=dim, base=base, pos_idx_in_fp32=True, device=device, interleaved=False) + self.max_seqlen = max_seqlen + + if max_seqlen is not None and device is not None and dtype is not None: + self._update_cos_sin_cache(max_seqlen, device=device, dtype=dtype) + + def forward( + self, + qkv: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: Optional[int] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + Apply rotary embedding *inplace* to qkv. + qkv: (total_nnz, 3, nheads, headdim) + cu_seqlens: (batch + 1,) cumulative sequence lengths + max_seqlen: int max seq length in the batch + """ + if max_seqlen is not None: + self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype) + + qkv = apply_rotary_unpadded( + qkv, + self._cos_cached, + self._sin_cached, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + + return qkv + + def extra_repr(self) -> str: + return f"dim={self.dim}, base={self.base}, scale_base={self.scale_base}" + + +class ModernBertEmbeddings(nn.Module): + """ + Same as BertEmbeddings with a tiny tweak for positional embeddings indexing. + """ + + def __init__(self, config: ModernBertConfig): + super().__init__() + self.config = config + self.tok_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) + self.drop = nn.Dropout(config.embedding_dropout) + + @torch.compile(dynamic=True) + def compiled_embeddings(self, input_ids: torch.LongTensor) -> torch.Tensor: + return self.drop(self.norm(self.tok_embeddings(input_ids))) + + def forward(self, input_ids: torch.LongTensor, position_ids: Optional[torch.LongTensor] = None) -> torch.Tensor: + hidden_states = ( + self.compiled_embeddings(input_ids) + if self.config.reference_compile + else self.drop(self.norm(self.tok_embeddings(input_ids))) + ) + return hidden_states + + +class ModernBertMLP(nn.Module): + """Applies the GLU at the end of each ModernBERT layer. 
+ + Compared to the default BERT architecture, this block replaces :class:`~transformers.model.bert.modeling_bert.BertIntermediate` + and :class:`~transformers.model.bert.modeling_bert.SelfOutput` with a single module that has similar functionality. + """ + + def __init__(self, config: ModernBertConfig): + super().__init__() + self.config = config + self.Wi = nn.Linear(config.hidden_size, int(config.intermediate_size) * 2, bias=config.mlp_bias) + self.act = ACT2FN[config.hidden_activation] + self.drop = nn.Dropout(config.mlp_dropout) + self.Wo = nn.Linear(config.intermediate_size, config.hidden_size, bias=config.mlp_bias) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + input, gate = self.Wi(hidden_states).chunk(2, dim=-1) + return self.Wo(self.drop(self.act(input) * gate)) + + +class ModernBertRotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim)) + self.register_buffer("inv_freq", tensor=inv_freq, persistent=False) + + @torch.no_grad() + def forward(self, x, position_ids, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + self.inv_freq.to(x.device) + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
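A shape-only sketch of the broadcasting described above, assuming `torch` plus the `rotate_half` and `apply_rotary_pos_emb` helpers defined in this file are in scope (the `cos`/`sin` tensors here are random placeholders rather than a real rotary cache):

```python
import torch

batch, heads, seq_len, head_dim = 2, 12, 16, 64
q = torch.randn(batch, heads, seq_len, head_dim)
k = torch.randn(batch, heads, seq_len, head_dim)

# cos/sin as produced by ModernBertRotaryEmbedding: [batch_size, seq_len, head_dim]
cos = torch.randn(batch, seq_len, head_dim)
sin = torch.randn(batch, seq_len, head_dim)

# unsqueeze_dim=1 inserts the broadcast axis at the heads dimension of q and k
q_embed, k_embed = apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1)
assert q_embed.shape == q.shape and k_embed.shape == k.shape
```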
+ """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def eager_attention_forward( + module: "ModernBertAttention", + qkv: torch.Tensor, + attention_mask: torch.Tensor, + sliding_window_mask: torch.Tensor, + position_ids: Optional[torch.LongTensor], + local_attention: Tuple[int, int], + bs: int, + dim: int, + output_attentions: Optional[bool] = False, + **_kwargs, +) -> Tuple[torch.Tensor, torch.Tensor] | Tuple[torch.Tensor]: + # qkv: [batch_size, seqlen, 3, nheads, headdim] + cos, sin = module.rotary_emb(qkv, position_ids=position_ids) + query, key, value = qkv.transpose(3, 1).unbind(dim=2) + # query, key, value: [batch_size, heads, seq_len, head_dim] + query, key = apply_rotary_pos_emb(query, key, cos, sin) + + scale = module.head_dim**-0.5 + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scale + + if local_attention != (-1, -1): + attention_mask = sliding_window_mask + + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=module.attention_dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bs, -1, dim) + if output_attentions: + return (attn_output, attn_weights) + return (attn_output,) + + +def flash_attention_forward( + module: "ModernBertAttention", + qkv: torch.Tensor, + rotary_emb: ModernBertUnpaddedRotaryEmbedding, + cu_seqlens: torch.Tensor, + max_seqlen: int, + local_attention: Tuple[int, int], + bs: int, + dim: int, + target_dtype: torch.dtype = torch.bfloat16, + **_kwargs, +) -> Tuple[torch.Tensor]: + # (total_seqlen, 3, nheads, headdim) + qkv = rotary_emb(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) + + convert_dtype = qkv.dtype not in (torch.float16, torch.bfloat16) + if convert_dtype: + # FA2 implementation only supports fp16 and bf16. If FA2 is supported, + # bfloat16 must be supported as of FA2 2.5.7. 
(Turing GPUs not supported) + orig_dtype = qkv.dtype + qkv = qkv.to(target_dtype) + + attn = flash_attn_varlen_qkvpacked_func( + qkv, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + dropout_p=module.attention_dropout if module.training else 0.0, + deterministic=module.deterministic_flash_attn, + window_size=local_attention, + ) + attn = attn.to(orig_dtype) # type: ignore + else: + attn = flash_attn_varlen_qkvpacked_func( + qkv, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + dropout_p=module.attention_dropout if module.training else 0.0, + deterministic=module.deterministic_flash_attn, + window_size=local_attention, + ) + return (attn.view(bs, dim),) + + +def sdpa_attention_forward( + module: "ModernBertAttention", + qkv: torch.Tensor, + attention_mask: torch.Tensor, + sliding_window_mask: torch.Tensor, + position_ids: Optional[torch.LongTensor], + local_attention: Tuple[int, int], + bs: int, + dim: int, + **_kwargs, +) -> Tuple[torch.Tensor]: + # qkv: [batch_size, seqlen, 3, nheads, headdim] + cos, sin = module.rotary_emb(qkv, position_ids=position_ids) + query, key, value = qkv.transpose(3, 1).unbind(dim=2) + # query, key, value: [batch_size, heads, seq_len, head_dim] + query, key = apply_rotary_pos_emb(query, key, cos, sin) + + if local_attention != (-1, -1): + attention_mask = sliding_window_mask + + attn_output = ( + F.scaled_dot_product_attention( + query, + key, + value, + dropout_p=module.attention_dropout if module.training else 0.0, + attn_mask=attention_mask, + ) + .transpose(1, 2) + .contiguous() + ) + attn_output = attn_output.view(bs, -1, dim) + return (attn_output,) + + +MODERNBERT_ATTENTION_FUNCTION = { + "flash_attention_2": flash_attention_forward, + "eager": eager_attention_forward, + "sdpa": sdpa_attention_forward, +} + + +class ModernBertAttention(nn.Module): + """Performs multi-headed self attention on a batch of unpadded sequences. + + If Flash Attention 2 is installed, this module uses Flash Attention to improve throughput. + If Flash Attention 2 is not installed, the implementation will use PyTorch's SDPA kernel, + which requires padding and unpadding inputs, adding some overhead. + + See `forward` method for additional details. 
+ """ + + def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None): + super().__init__() + self.config = config + self.layer_id = layer_id + + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention heads ({config.num_attention_heads})" + ) + + self.attention_dropout = config.attention_dropout + self.deterministic_flash_attn = config.deterministic_flash_attn + self.num_heads = config.num_attention_heads + self.head_dim = config.hidden_size // config.num_attention_heads + self.all_head_size = self.head_dim * self.num_heads + self.Wqkv = nn.Linear(config.hidden_size, 3 * self.all_head_size, bias=config.attention_bias) + + if layer_id % config.global_attn_every_n_layers != 0: + self.local_attention = (config.local_attention // 2, config.local_attention // 2) + else: + self.local_attention = (-1, -1) + + rope_theta = config.global_rope_theta + max_position_embeddings = config.max_position_embeddings + if self.local_attention != (-1, -1): + if config.local_rope_theta is not None: + rope_theta = config.local_rope_theta + max_position_embeddings = config.local_attention + + if config._attn_implementation == "flash_attention_2": + self.rotary_emb = ModernBertUnpaddedRotaryEmbedding( + dim=self.head_dim, max_seqlen=max_position_embeddings, base=rope_theta + ) + else: + self.rotary_emb = ModernBertRotaryEmbedding( + dim=self.head_dim, max_position_embeddings=max_position_embeddings, base=rope_theta + ) + + self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias) + self.out_drop = nn.Dropout(config.attention_dropout) if config.attention_dropout > 0.0 else nn.Identity() + self.pruned_heads = set() + + def forward( + self, + hidden_states: torch.Tensor, + output_attentions: Optional[bool] = False, + **kwargs, + ) -> torch.Tensor: + qkv = self.Wqkv(hidden_states) + + bs = hidden_states.shape[0] + if self.config._attn_implementation == "flash_attention_2": + qkv = qkv.view(-1, 3, self.num_heads, self.head_dim) + else: + qkv = qkv.view(bs, -1, 3, self.num_heads, self.head_dim) + + attn_outputs = MODERNBERT_ATTENTION_FUNCTION[self.config._attn_implementation]( + self, + qkv=qkv, + rotary_emb=self.rotary_emb, + local_attention=self.local_attention, + bs=bs, + dim=self.all_head_size, + output_attentions=output_attentions, + **kwargs, + ) + hidden_states = attn_outputs[0] + hidden_states = self.out_drop(self.Wo(hidden_states)) + + return (hidden_states,) + attn_outputs[1:] # add attentions if outputted + + +class ModernBertEncoderLayer(nn.Module): + def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None): + super().__init__() + self.config = config + if layer_id == 0: + self.attn_norm = nn.Identity() + else: + self.attn_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) + self.attn = ModernBertAttention(config=config, layer_id=layer_id) + self.mlp_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) + self.mlp = ModernBertMLP(config) + + @torch.compile(dynamic=True) + def compiled_mlp(self, hidden_states: torch.Tensor) -> torch.Tensor: + return self.mlp(self.mlp_norm(hidden_states)) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + sliding_window_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, 
+ output_attentions: Optional[bool] = False, + ) -> torch.Tensor: + attn_outputs = self.attn( + self.attn_norm(hidden_states), + attention_mask=attention_mask, + sliding_window_mask=sliding_window_mask, + position_ids=position_ids, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + output_attentions=output_attentions, + ) + hidden_states = hidden_states + attn_outputs[0] + mlp_output = ( + self.compiled_mlp(hidden_states) + if self.config.reference_compile + else self.mlp(self.mlp_norm(hidden_states)) + ) + hidden_states = hidden_states + mlp_output + + return (hidden_states,) + attn_outputs[1:] # add attentions if outputted + + +MODERNBERT_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`ModernBertConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare ModernBert Model outputting raw hidden-states without any specific head on top.", + MODERNBERT_START_DOCSTRING, +) +class ModernBertPreTrainedModel(PreTrainedModel): + config_class = ModernBertConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["ModernBertEmbeddings", "ModernBertEncoderLayer"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = False + + def _init_weights(self, module: nn.Module): + cutoff_factor = self.config.initializer_cutoff_factor + if cutoff_factor is None: + cutoff_factor = 3 + + def init_weight(module: nn.Module, std: float): + nn.init.trunc_normal_( + module.weight, + mean=0.0, + std=std, + a=-cutoff_factor * std, + b=cutoff_factor * std, + ) + + if isinstance(module, nn.Linear): + if module.bias is not None: + nn.init.zeros_(module.bias) + + stds = { + "in": self.config.initializer_range, + "out": self.config.initializer_range / math.sqrt(2.0 * self.config.num_hidden_layers), + "embedding": self.config.initializer_range, + "final_out": self.config.hidden_size**-0.5, + } + + if isinstance(module, ModernBertEmbeddings): + init_weight(module.tok_embeddings, stds["embedding"]) + elif isinstance(module, ModernBertMLP): + init_weight(module.Wi, stds["in"]) + init_weight(module.Wo, stds["out"]) + elif isinstance(module, ModernBertAttention): + init_weight(module.Wqkv, stds["in"]) + init_weight(module.Wo, stds["out"]) + elif isinstance(module, ModernBertPredictionHead): + init_weight(module.dense, stds["in"]) + elif isinstance(module, ModernBertPoolingHead): + init_weight(module.dense, stds["out"]) + elif isinstance(module, ModernBertForMaskedLM): + init_weight(module.decoder, stds["out"]) + elif isinstance(module, (ModernBertForSequenceClassification, ModernBertForTokenClassification)): + init_weight(module.classifier, stds["final_out"]) + + @classmethod + def _autoset_attn_implementation( + cls, + config, + use_flash_attention_2: bool = False, + torch_dtype: Optional[torch.dtype] = None, + device_map: 
Optional[Union[str, Dict[str, int]]] = None, + check_device_map: bool = True, + ): + # If the user didn't specify anything, try to use flash_attention_2 if available. + # Otherwise we fall back to the default SDPA -> Eager from the super() method. + if config._attn_implementation_internal is None: + config._attn_implementation_internal = "flash_attention_2" + try: + return cls._check_and_enable_flash_attn_2( + config, + torch_dtype=torch_dtype, + device_map=device_map, + hard_check_only=False, + check_device_map=check_device_map, + ) + except (ValueError, ImportError): + config._attn_implementation_internal = None + return super()._autoset_attn_implementation( + config, + use_flash_attention_2=use_flash_attention_2, + torch_dtype=torch_dtype, + device_map=device_map, + check_device_map=check_device_map, + ) + + def _maybe_set_compile(self): + if self.config.reference_compile is False: + return + + if hasattr(self, "hf_device_map") and len(self.hf_device_map) > 1: + if self.config.reference_compile: + logger.warning_once( + "If `accelerate` split the model across devices, `torch.compile` will not work. " + "Falling back to non-compiled mode." + ) + self.config.reference_compile = False + + if self.device.type == "mps": + if self.config.reference_compile: + logger.warning_once( + "Compiling the model with `torch.compile` and using a `torch.mps` device is not supported. " + "Falling back to non-compiled mode." + ) + self.config.reference_compile = False + + if self.config.reference_compile is None: + self.config.reference_compile = is_triton_available() + + def resize_token_embeddings(self, *args, **kwargs): + model_embeds = super().resize_token_embeddings(*args, **kwargs) + + if self.config.reference_compile in {True, None}: + if self.config.reference_compile: + logger.warning_once( + "Resizing token embeddings with `torch.compile` is not supported. Falling back to non-compiled mode." + ) + self.config.reference_compile = False + + return model_embeds + + +def _unpad_modernbert_input( + inputs: torch.Tensor, + attention_mask: torch.Tensor, + position_ids: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, Optional[torch.Tensor], Optional[torch.Tensor]]: + """ + Remove padding from input sequences. + + Args: + inputs: (batch, seqlen, ...) or (batch, seqlen) + attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid. + position_ids: (batch, seqlen), int, position ids + labels: (batch, seqlen), int, labels + + Returns: + unpadded_inputs: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask. 
+ indices: (total_nnz) + cu_seqlens: (batch + 1), the cumulative sequence lengths + max_seqlen_in_batch: int + unpadded_position_ids: (total_nnz) or None + unpadded_labels: (total_nnz) or None + """ + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = int(seqlens_in_batch.max().item()) + cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + + if inputs.dim() == 2: + unpadded_inputs = inputs.flatten()[indices] + else: + batch, seqlen, *rest = inputs.shape + shape = batch * seqlen + unpadded_inputs = inputs.view(shape, *rest)[indices] + + unpadded_position_ids = position_ids.flatten()[indices] if position_ids is not None else None + unpadded_labels = labels.flatten()[indices] if labels is not None else None + + return unpadded_inputs, indices, cu_seqlens, max_seqlen_in_batch, unpadded_position_ids, unpadded_labels + + +def _pad_modernbert_output( + inputs: torch.Tensor, + indices: torch.Tensor, + batch: int, + seqlen: int, +) -> torch.Tensor: + """ + Add padding to sequences. + + Args: + inputs: (total_nnz, ...) or (total_nnz,), where total_nnz = number of tokens selected in attention_mask. + indices: (total_nnz) + batch: int, batch size + seqlen: int, max sequence length + + Returns: + padded_inputs: (batch, seqlen, ...) or (batch, seqlen) + """ + if inputs.dim() == 1: + output = torch.zeros(batch * seqlen, dtype=inputs.dtype, device=inputs.device) + output[indices] = inputs + padded_inputs = output.view(batch, seqlen) + else: + _, *rest = inputs.shape + output = torch.zeros(batch * seqlen, *rest, dtype=inputs.dtype, device=inputs.device) + output[indices] = inputs + padded_inputs = output.view(batch, seqlen, *rest) + + return padded_inputs + + +MODERNBERT_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers + perform global attention, while the rest perform local attention. This mask is used to avoid attending to + far-away tokens in the local attention layers. 
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*): + Indices of the non-padding tokens in the input sequence. Used for unpadding the output. + cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*): + Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors. + max_seqlen (`int`, *optional*): + Maximum sequence length in the batch. Used to pad the output tensors. + batch_size (`int`, *optional*): + Batch size of the input sequences. Used to pad the output tensors. + seq_len (`int`, *optional*): + Sequence length of the input sequences. Used to pad the output tensors. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare ModernBert Model outputting raw hidden-states without any specific head on top.", + MODERNBERT_START_DOCSTRING, +) +class ModernBertModel(ModernBertPreTrainedModel): + def __init__(self, config: ModernBertConfig): + super().__init__(config) + self.config = config + self.embeddings = ModernBertEmbeddings(config) + self.layers = nn.ModuleList( + [ModernBertEncoderLayer(config, layer_id) for layer_id in range(config.num_hidden_layers)] + ) + self.final_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) + self.gradient_checkpointing = False + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.tok_embeddings + + def set_input_embeddings(self, value): + self.embeddings.tok_embeddings = value + + @add_start_docstrings_to_model_forward(MODERNBERT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + sliding_window_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + indices: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + batch_size: Optional[int] = None, + seq_len: Optional[int] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + self._maybe_set_compile() + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + + if batch_size is None and seq_len 
is None: + batch_size, seq_len = input_ids.shape[:2] + + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_len), device=input_ids.device, dtype=torch.bool) + + repad = False + if self.config._attn_implementation == "flash_attention_2": + if indices is None and cu_seqlens is None and max_seqlen is None: + repad = True + with torch.no_grad(): + input_ids, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input( + inputs=input_ids, attention_mask=attention_mask + ) + else: + if position_ids is None: + position_ids = torch.arange(seq_len, device=input_ids.device).unsqueeze(0) + + attention_mask, sliding_window_mask = self._update_attention_mask( + attention_mask, output_attentions=output_attentions + ) + + hidden_states = self.embeddings(input_ids) + + for encoder_layer in self.layers: + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + encoder_layer.__call__, + hidden_states, + attention_mask, + sliding_window_mask, + position_ids, + cu_seqlens, + max_seqlen, + output_attentions, + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask=attention_mask, + sliding_window_mask=sliding_window_mask, + position_ids=position_ids, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + output_attentions=output_attentions, + ) + hidden_states = layer_outputs[0] + if output_attentions and len(layer_outputs) > 1: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + hidden_states = self.final_norm(hidden_states) + + if repad: + hidden_states = _pad_modernbert_output( + inputs=hidden_states, indices=indices, batch=batch_size, seqlen=seq_len + ) + if all_hidden_states is not None: + all_hidden_states = tuple( + _pad_modernbert_output(inputs=hs, indices=indices, batch=batch_size, seqlen=seq_len) + for hs in all_hidden_states + ) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + def _update_attention_mask(self, attention_mask: torch.Tensor, output_attentions: bool) -> torch.Tensor: + if output_attentions: + if self.config._attn_implementation == "sdpa": + logger.warning_once( + "Outputting attentions is only supported with the 'eager' attention implementation, " + 'not with "sdpa". Falling back to `attn_implementation="eager"`.' + ) + self.config._attn_implementation = "eager" + elif self.config._attn_implementation != "eager": + logger.warning_once( + "Outputting attentions is only supported with the eager attention implementation, " + f'not with {self.config._attn_implementation}. Consider setting `attn_implementation="eager"`.' + " Setting `output_attentions=False`." 
+ ) + + global_attention_mask = _prepare_4d_attention_mask(attention_mask, self.dtype) + + # Create position indices + rows = torch.arange(global_attention_mask.shape[2]).unsqueeze(0) + # Calculate distance between positions + distance = torch.abs(rows - rows.T) + + # Create sliding window mask (1 for positions within window, 0 outside) + window_mask = ( + (distance <= self.config.local_attention // 2).unsqueeze(0).unsqueeze(0).to(attention_mask.device) + ) + # Combine with existing mask + sliding_window_mask = global_attention_mask.masked_fill(window_mask.logical_not(), torch.finfo(self.dtype).min) + + return global_attention_mask, sliding_window_mask + + +class ModernBertPredictionHead(nn.Module): + def __init__(self, config: ModernBertConfig): + super().__init__() + self.config = config + self.dense = nn.Linear(config.hidden_size, config.hidden_size, config.classifier_bias) + self.act = ACT2FN[config.classifier_activation] + self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return self.norm(self.act(self.dense(hidden_states))) + + +@add_start_docstrings( + "The ModernBert Model with a decoder head on top that is used for masked language modeling.", + MODERNBERT_START_DOCSTRING, +) +class ModernBertForMaskedLM(ModernBertPreTrainedModel): + _tied_weights_keys = ["decoder.weight"] + + def __init__(self, config: ModernBertConfig): + super().__init__(config) + self.config = config + self.model = ModernBertModel(config) + self.head = ModernBertPredictionHead(config) + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=config.decoder_bias) + + self.sparse_prediction = self.config.sparse_prediction + self.sparse_pred_ignore_index = self.config.sparse_pred_ignore_index + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.decoder + + def set_output_embeddings(self, new_embeddings: nn.Linear): + self.decoder = new_embeddings + + @torch.compile(dynamic=True) + def compiled_head(self, output: torch.Tensor) -> torch.Tensor: + return self.decoder(self.head(output)) + + @add_start_docstrings_to_model_forward(MODERNBERT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=MaskedLMOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + sliding_window_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + indices: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + batch_size: Optional[int] = None, + seq_len: Optional[int] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + self._maybe_set_compile() + + if self.config._attn_implementation == "flash_attention_2": + if indices is None and cu_seqlens is None and max_seqlen is None: + batch_size, seq_len = input_ids.shape[:2] + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_len), device=input_ids.device, dtype=torch.bool) + with torch.no_grad(): + input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = 
_unpad_modernbert_input( + inputs=input_ids, attention_mask=attention_mask, position_ids=position_ids, labels=labels + ) + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + sliding_window_mask=sliding_window_mask, + position_ids=position_ids, + indices=indices, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + batch_size=batch_size, + seq_len=seq_len, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + last_hidden_state = outputs[0] + + if self.sparse_prediction and labels is not None: + # flatten labels and output first + labels = labels.view(-1) + last_hidden_state = last_hidden_state.view(labels.shape[0], -1) + + # then filter out the non-masked tokens + mask_tokens = labels != self.sparse_pred_ignore_index + last_hidden_state = last_hidden_state[mask_tokens] + labels = labels[mask_tokens] + + logits = ( + self.compiled_head(last_hidden_state) + if self.config.reference_compile + else self.decoder(self.head(last_hidden_state)) + ) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, vocab_size=self.config.vocab_size) + + if self.config._attn_implementation == "flash_attention_2": + with torch.no_grad(): + logits = _pad_modernbert_output(inputs=logits, indices=indices, batch=batch_size, seqlen=seq_len) + if not return_dict: + output = (logits,) + return ((loss,) + output) if loss is not None else output + + return MaskedLMOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class ModernBertPoolingHead(nn.Module): + def __init__(self, config: ModernBertConfig): + super().__init__() + self.config = config + self.dense = nn.Linear(config.hidden_size, config.hidden_size, config.classifier_bias) + self.act = ACT2FN[config.classifier_activation] + self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) + self.drop = torch.nn.Dropout(config.classifier_dropout) + + def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: + if self.config.classifier_pooling == "cls": + hidden_states = hidden_states[:, 0] + elif self.config.classifier_pooling == "mean": + hidden_states = (hidden_states * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum( + dim=1, keepdim=True + ) + + return self.drop(self.norm(self.act(self.dense(hidden_states)))) + + +@add_start_docstrings( + "The ModernBert Model with a sequence classification head on top that performs pooling.", + MODERNBERT_START_DOCSTRING, +) +class ModernBertForSequenceClassification(ModernBertPreTrainedModel): + def __init__(self, config: ModernBertConfig): + super().__init__(config) + self.num_labels = config.num_labels + self.config = config + + self.model = ModernBertModel(config) + self.head = ModernBertPoolingHead(config) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(MODERNBERT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=SequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + sliding_window_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + indices: Optional[torch.Tensor] = None, + cu_seqlens: 
Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + batch_size: Optional[int] = None, + seq_len: Optional[int] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + self._maybe_set_compile() + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + sliding_window_mask=sliding_window_mask, + position_ids=position_ids, + indices=indices, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + batch_size=batch_size, + seq_len=seq_len, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + last_hidden_state = outputs[0] + + pooled_output = self.head(last_hidden_state, attention_mask) + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits,) + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + "The ModernBert Model with a token classification head on top, e.g. 
for Named Entity Recognition (NER) tasks.", + MODERNBERT_START_DOCSTRING, +) +class ModernBertForTokenClassification(ModernBertPreTrainedModel): + def __init__(self, config: ModernBertConfig): + super().__init__(config) + self.num_labels = config.num_labels + + self.model = ModernBertModel(config) + self.drop = nn.Dropout(config.classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(MODERNBERT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + sliding_window_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + indices: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + batch_size: Optional[int] = None, + seq_len: Optional[int] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + self._maybe_set_compile() + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + sliding_window_mask=sliding_window_mask, + position_ids=position_ids, + indices=indices, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + batch_size=batch_size, + seq_len=seq_len, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + last_hidden_state = outputs[0] + + last_hidden_state = self.drop(last_hidden_state) + logits = self.classifier(last_hidden_state) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = [ + "ModernBertModel", + "ModernBertPreTrainedModel", + "ModernBertForMaskedLM", + "ModernBertForSequenceClassification", + "ModernBertForTokenClassification", +] diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py new file mode 100644 index 00000000000000..3c23f9178b1b51 --- /dev/null +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -0,0 +1,1452 @@ +# Copyright 2024 Answer.AI, LightOn, and contributors, and the HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Dict, Literal, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN +from ...configuration_utils import PretrainedConfig +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask +from ...modeling_outputs import ( + BaseModelOutput, + MaskedLMOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from ...modeling_utils import PreTrainedModel +from ...utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + logging, +) +from ...utils.import_utils import is_triton_available +from ..gemma.modeling_gemma import GemmaRotaryEmbedding, apply_rotary_pos_emb + + +if is_flash_attn_2_available(): + from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func + from flash_attn.layers.rotary import RotaryEmbedding + from flash_attn.ops.triton.rotary import apply_rotary +else: + RotaryEmbedding = object + +_CHECKPOINT_FOR_DOC = "answerdotai/ModernBERT-base" +_CONFIG_FOR_DOC = "ModernBertConfig" + +logger = logging.get_logger(__name__) + + +class ModernBertConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`ModernBertModel`]. It is used to instantiate an ModernBert + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the ModernBERT-base. + e.g. [answerdotai/ModernBERT-base](https://huggingface.co/answerdotai/ModernBERT-base) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 50368): + Vocabulary size of the ModernBert model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`ModernBertModel`] + hidden_size (`int`, *optional*, defaults to 768): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 1152): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 22): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer decoder. + hidden_activation (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the decoder. Will default to `"gelu"` + if not specified. + max_position_embeddings (`int`, *optional*, defaults to 8192): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 
+ initializer_cutoff_factor (`float`, *optional*, defaults to 2.0): + The cutoff factor for the truncated_normal_initializer for initializing all weight matrices. + norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. + norm_bias (`bool`, *optional*, defaults to `False`): + Whether to use bias in the normalization layers. + pad_token_id (`int`, *optional*, defaults to 50283): + Padding token id. + eos_token_id (`int`, *optional*, defaults to 50282): + End of stream token id. + bos_token_id (`int`, *optional*, defaults to 50281): + Beginning of stream token id. + cls_token_id (`int`, *optional*, defaults to 50281): + Classification token id. + sep_token_id (`int`, *optional*, defaults to 50282): + Separation token id. + global_rope_theta (`float`, *optional*, defaults to 160000.0): + The base period of the global RoPE embeddings. + attention_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + global_attn_every_n_layers (`int`, *optional*, defaults to 3): + The number of layers between global attention layers. + local_attention (`int`, *optional*, defaults to 128): + The window size for local attention. + local_rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the local RoPE embeddings. + embedding_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the embeddings. + mlp_bias (`bool`, *optional*, defaults to `False`): + Whether to use bias in the MLP layers. + mlp_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the MLP layers. + decoder_bias (`bool`, *optional*, defaults to `True`): + Whether to use bias in the decoder layers. + classifier_pooling (`str`, *optional*, defaults to `"cls"`): + The pooling method for the classifier. Should be either `"cls"` or `"mean"`. In local attention layers, the + CLS token doesn't attend to all tokens on long sequences. + classifier_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the classifier. + classifier_bias (`bool`, *optional*, defaults to `False`): + Whether to use bias in the classifier. + classifier_activation (`str`, *optional*, defaults to `"gelu"`): + The activation function for the classifier. + deterministic_flash_attn (`bool`, *optional*, defaults to `False`): + Whether to use deterministic flash attention. If `False`, inference will be faster but not deterministic. + sparse_prediction (`bool`, *optional*, defaults to `False`): + Whether to use sparse prediction for the masked language model instead of returning the full dense logits. + sparse_pred_ignore_index (`int`, *optional*, defaults to -100): + The index to ignore for the sparse prediction. + reference_compile (`bool`, *optional*): + Whether to compile the layers of the model which were compiled during pretraining. If `None`, then parts of + the model will be compiled if 1) `triton` is installed, 2) the model is not on MPS, 3) the model is not + shared between devices, and 4) the model is not resized after initialization. If `True`, then the model may + be faster in some scenarios. 
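The `sparse_prediction` and `sparse_pred_ignore_index` arguments above control whether the masked-LM head sees every position or only the masked ones. A toy sketch of that filtering (shapes are illustrative; it mirrors the sparse-prediction branch of `ModernBertForMaskedLM.forward` earlier in this patch):

```python
import torch

# 2 sequences of length 4, hidden size 8; -100 marks positions without an MLM label
hidden = torch.randn(2, 4, 8)
labels = torch.tensor([[-100, 5, -100, 7],
                       [-100, -100, 9, -100]])

flat_labels = labels.view(-1)
flat_hidden = hidden.view(flat_labels.shape[0], -1)

# keep only the positions whose label differs from the ignore index
keep = flat_labels != -100
flat_hidden, flat_labels = flat_hidden[keep], flat_labels[keep]

print(flat_hidden.shape)  # torch.Size([3, 8]) -> only the 3 labelled positions reach the decoder head
```

Skipping the unlabelled positions before the projection to `vocab_size` is what makes the sparse path cheaper than computing the full dense logits.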
+ + Examples: + + ```python + >>> from transformers import ModernBertModel, ModernBertConfig + + >>> # Initializing a ModernBert style configuration + >>> configuration = ModernBertConfig() + + >>> # Initializing a model from the modernbert-base style configuration + >>> model = ModernBertModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "modernbert" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=50368, + hidden_size=768, + intermediate_size=1152, + num_hidden_layers=22, + num_attention_heads=12, + hidden_activation="gelu", + max_position_embeddings=8192, + initializer_range=0.02, + initializer_cutoff_factor=2.0, + norm_eps=1e-5, + norm_bias=False, + pad_token_id=50283, + eos_token_id=50282, + bos_token_id=50281, + cls_token_id=50281, + sep_token_id=50282, + global_rope_theta=160000.0, + attention_bias=False, + attention_dropout=0.0, + global_attn_every_n_layers=3, + local_attention=128, + local_rope_theta=10000.0, + embedding_dropout=0.0, + mlp_bias=False, + mlp_dropout=0.0, + decoder_bias=True, + classifier_pooling: Literal["cls", "mean"] = "cls", + classifier_dropout=0.0, + classifier_bias=False, + classifier_activation="gelu", + deterministic_flash_attn=False, + sparse_prediction=False, + sparse_pred_ignore_index=-100, + reference_compile=None, + **kwargs, + ): + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + cls_token_id=cls_token_id, + sep_token_id=sep_token_id, + **kwargs, + ) + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.initializer_range = initializer_range + self.initializer_cutoff_factor = initializer_cutoff_factor + self.norm_eps = norm_eps + self.norm_bias = norm_bias + self.global_rope_theta = global_rope_theta + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.hidden_activation = hidden_activation + self.global_attn_every_n_layers = global_attn_every_n_layers + self.local_attention = local_attention + self.local_rope_theta = local_rope_theta + self.embedding_dropout = embedding_dropout + self.mlp_bias = mlp_bias + self.mlp_dropout = mlp_dropout + self.decoder_bias = decoder_bias + self.classifier_pooling = classifier_pooling + self.classifier_dropout = classifier_dropout + self.classifier_bias = classifier_bias + self.classifier_activation = classifier_activation + self.deterministic_flash_attn = deterministic_flash_attn + self.sparse_prediction = sparse_prediction + self.sparse_pred_ignore_index = sparse_pred_ignore_index + self.reference_compile = reference_compile + + if self.classifier_pooling not in ["cls", "mean"]: + raise ValueError( + f'Invalid value for `classifier_pooling`, should be either "cls" or "mean", but is {self.classifier_pooling}.' + ) + + +def _unpad_modernbert_input( + inputs: torch.Tensor, + attention_mask: torch.Tensor, + position_ids: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, Optional[torch.Tensor], Optional[torch.Tensor]]: + """ + Remove padding from input sequences. + + Args: + inputs: (batch, seqlen, ...) or (batch, seqlen) + attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid. 
+ position_ids: (batch, seqlen), int, position ids + labels: (batch, seqlen), int, labels + + Returns: + unpadded_inputs: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask. + indices: (total_nnz) + cu_seqlens: (batch + 1), the cumulative sequence lengths + max_seqlen_in_batch: int + unpadded_position_ids: (total_nnz) or None + unpadded_labels: (total_nnz) or None + """ + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = int(seqlens_in_batch.max().item()) + cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + + if inputs.dim() == 2: + unpadded_inputs = inputs.flatten()[indices] + else: + batch, seqlen, *rest = inputs.shape + shape = batch * seqlen + unpadded_inputs = inputs.view(shape, *rest)[indices] + + unpadded_position_ids = position_ids.flatten()[indices] if position_ids is not None else None + unpadded_labels = labels.flatten()[indices] if labels is not None else None + + return unpadded_inputs, indices, cu_seqlens, max_seqlen_in_batch, unpadded_position_ids, unpadded_labels + + +def _pad_modernbert_output( + inputs: torch.Tensor, + indices: torch.Tensor, + batch: int, + seqlen: int, +) -> torch.Tensor: + """ + Add padding to sequences. + + Args: + inputs: (total_nnz, ...) or (total_nnz,), where total_nnz = number of tokens selected in attention_mask. + indices: (total_nnz) + batch: int, batch size + seqlen: int, max sequence length + + Returns: + padded_inputs: (batch, seqlen, ...) or (batch, seqlen) + """ + if inputs.dim() == 1: + output = torch.zeros(batch * seqlen, dtype=inputs.dtype, device=inputs.device) + output[indices] = inputs + padded_inputs = output.view(batch, seqlen) + else: + _, *rest = inputs.shape + output = torch.zeros(batch * seqlen, *rest, dtype=inputs.dtype, device=inputs.device) + output[indices] = inputs + padded_inputs = output.view(batch, seqlen, *rest) + + return padded_inputs + + +class ApplyRotaryEmbUnpad(torch.autograd.Function): + @staticmethod + def forward( + ctx, + qkv, + cos, + sin, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + ): + # (total_nnz, 3, nheads, headdim) + qkv = qkv.contiguous() + total_nnz, _three, _nheads, headdim = qkv.shape + # We need qkv to be contiguous so that when we reshape to combine (3, nheads) dimensions, + # we get the same tensor + # qk = rearrange(qkv[:, :2], "b_s t h d -> b_s (t h) d") + qk = qkv[:, :2].view(total_nnz, -1, headdim) + apply_rotary( + qk, + cos, + sin, + seqlen_offsets=0, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + interleaved=False, + inplace=True, + ) + + ctx.save_for_backward(cos, sin, cu_seqlens) + ctx.max_seqlen = max_seqlen + return qkv + + @staticmethod + def backward(ctx, do): + cos, sin, cu_seqlens = ctx.saved_tensors + do = do.contiguous() + total_nnz, _three, _nheads, headdim = do.shape + # We need dqkv to be contiguous so that when we reshape to combine (3, nheads) dimensions, + # we get the same tensor + dqk = do[:, :2].view(total_nnz, -1, headdim) + apply_rotary( + dqk, + cos, + sin, + seqlen_offsets=0, + cu_seqlens=cu_seqlens, + max_seqlen=ctx.max_seqlen, + interleaved=False, + inplace=True, + conjugate=True, + ) + + return do, None, None, None, None, None, None + + +def apply_rotary_unpadded( + qkv, + cos, + sin, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, +): + """ + Arguments: + qkv: (total_nnz, 3, nheads, headdim) 
- input tensor for packed QKV. + cos, sin: (seqlen_rotary, rotary_dim / 2) + interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead + of 1st half and 2nd half (GPT-NeoX style). + inplace: if True, apply rotary embedding in-place. + seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount. + Most commonly used in inference when we have KV cache. + cu_seqlens: (batch + 1,) or None + max_seqlen: int + Return: + out: (total_nnz, dim) + rotary_dim must be <= headdim + Apply rotary embedding to the first rotary_dim of x. + """ + return ApplyRotaryEmbUnpad.apply(qkv, cos, sin, cu_seqlens, max_seqlen) + + +class ModernBertUnpaddedRotaryEmbedding(RotaryEmbedding): + """ + The rotary position embeddings applied directly to unpadded sequences. + """ + + def __init__( + self, + dim: int, + base: float = 10000.0, + max_seqlen: Optional[int] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + """ + max_seqlen: if max_seqlen, device, and dtype are provided, we precompute the cos_sin_cache + up to max_seqlen. If the max_seqlen, device, or dtype during training/inference differ, + the cos_sin_cache wll be recomputed during the forward pass. + """ + super().__init__(dim=dim, base=base, pos_idx_in_fp32=True, device=device, interleaved=False) + self.max_seqlen = max_seqlen + + if max_seqlen is not None and device is not None and dtype is not None: + self._update_cos_sin_cache(max_seqlen, device=device, dtype=dtype) + + def forward( + self, + qkv: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: Optional[int] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + Apply rotary embedding *inplace* to qkv. + qkv: (total_nnz, 3, nheads, headdim) + cu_seqlens: (batch + 1,) cumulative sequence lengths + max_seqlen: int max seq length in the batch + """ + if max_seqlen is not None: + self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype) + + qkv = apply_rotary_unpadded( + qkv, + self._cos_cached, + self._sin_cached, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + + return qkv + + def extra_repr(self) -> str: + return f"dim={self.dim}, base={self.base}, scale_base={self.scale_base}" + + +class ModernBertEmbeddings(nn.Module): + """ + Same as BertEmbeddings with a tiny tweak for positional embeddings indexing. + """ + + def __init__(self, config: ModernBertConfig): + super().__init__() + self.config = config + self.tok_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) + self.drop = nn.Dropout(config.embedding_dropout) + + @torch.compile(dynamic=True) + def compiled_embeddings(self, input_ids: torch.LongTensor) -> torch.Tensor: + return self.drop(self.norm(self.tok_embeddings(input_ids))) + + def forward(self, input_ids: torch.LongTensor, position_ids: Optional[torch.LongTensor] = None) -> torch.Tensor: + hidden_states = ( + self.compiled_embeddings(input_ids) + if self.config.reference_compile + else self.drop(self.norm(self.tok_embeddings(input_ids))) + ) + return hidden_states + + +class ModernBertMLP(nn.Module): + """Applies the GLU at the end of each ModernBERT layer. + + Compared to the default BERT architecture, this block replaces :class:`~transformers.model.bert.modeling_bert.BertIntermediate` + and :class:`~transformers.model.bert.modeling_bert.SelfOutput` with a single module that has similar functionality. 
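+
+    A minimal, self-contained sketch of the gated (GLU-style) computation this block performs. The sizes below
+    mirror the configuration defaults and are illustrative only:
+
+    ```python
+    import torch
+    from torch import nn
+
+    hidden_size, intermediate_size = 768, 1152
+    Wi = nn.Linear(hidden_size, intermediate_size * 2, bias=False)  # fused up-projection: input and gate halves
+    Wo = nn.Linear(intermediate_size, hidden_size, bias=False)  # down-projection back to hidden_size
+    act = nn.GELU()
+
+    hidden_states = torch.randn(2, 16, hidden_size)  # (batch_size, seq_len, hidden_size)
+    input, gate = Wi(hidden_states).chunk(2, dim=-1)  # split the fused projection into two halves
+    output = Wo(act(input) * gate)  # activate one half, gate with the other, project back
+    ```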
+ """ + + def __init__(self, config: ModernBertConfig): + super().__init__() + self.config = config + self.Wi = nn.Linear(config.hidden_size, int(config.intermediate_size) * 2, bias=config.mlp_bias) + self.act = ACT2FN[config.hidden_activation] + self.drop = nn.Dropout(config.mlp_dropout) + self.Wo = nn.Linear(config.intermediate_size, config.hidden_size, bias=config.mlp_bias) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + input, gate = self.Wi(hidden_states).chunk(2, dim=-1) + return self.Wo(self.drop(self.act(input) * gate)) + + +class ModernBertRotaryEmbedding(GemmaRotaryEmbedding): + pass + + +def eager_attention_forward( + module: "ModernBertAttention", + qkv: torch.Tensor, + attention_mask: torch.Tensor, + sliding_window_mask: torch.Tensor, + position_ids: Optional[torch.LongTensor], + local_attention: Tuple[int, int], + bs: int, + dim: int, + output_attentions: Optional[bool] = False, + **_kwargs, +) -> Tuple[torch.Tensor, torch.Tensor] | Tuple[torch.Tensor]: + # qkv: [batch_size, seqlen, 3, nheads, headdim] + cos, sin = module.rotary_emb(qkv, position_ids=position_ids) + query, key, value = qkv.transpose(3, 1).unbind(dim=2) + # query, key, value: [batch_size, heads, seq_len, head_dim] + query, key = apply_rotary_pos_emb(query, key, cos, sin) + + scale = module.head_dim**-0.5 + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scale + + if local_attention != (-1, -1): + attention_mask = sliding_window_mask + + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=module.attention_dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bs, -1, dim) + if output_attentions: + return (attn_output, attn_weights) + return (attn_output,) + + +def flash_attention_forward( + module: "ModernBertAttention", + qkv: torch.Tensor, + rotary_emb: ModernBertUnpaddedRotaryEmbedding, + cu_seqlens: torch.Tensor, + max_seqlen: int, + local_attention: Tuple[int, int], + bs: int, + dim: int, + target_dtype: torch.dtype = torch.bfloat16, + **_kwargs, +) -> Tuple[torch.Tensor]: + # (total_seqlen, 3, nheads, headdim) + qkv = rotary_emb(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) + + convert_dtype = qkv.dtype not in (torch.float16, torch.bfloat16) + if convert_dtype: + # FA2 implementation only supports fp16 and bf16. If FA2 is supported, + # bfloat16 must be supported as of FA2 2.5.7. 
(Turing GPUs not supported) + orig_dtype = qkv.dtype + qkv = qkv.to(target_dtype) + + attn = flash_attn_varlen_qkvpacked_func( + qkv, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + dropout_p=module.attention_dropout if module.training else 0.0, + deterministic=module.deterministic_flash_attn, + window_size=local_attention, + ) + attn = attn.to(orig_dtype) # type: ignore + else: + attn = flash_attn_varlen_qkvpacked_func( + qkv, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + dropout_p=module.attention_dropout if module.training else 0.0, + deterministic=module.deterministic_flash_attn, + window_size=local_attention, + ) + return (attn.view(bs, dim),) + + +def sdpa_attention_forward( + module: "ModernBertAttention", + qkv: torch.Tensor, + attention_mask: torch.Tensor, + sliding_window_mask: torch.Tensor, + position_ids: Optional[torch.LongTensor], + local_attention: Tuple[int, int], + bs: int, + dim: int, + **_kwargs, +) -> Tuple[torch.Tensor]: + # qkv: [batch_size, seqlen, 3, nheads, headdim] + cos, sin = module.rotary_emb(qkv, position_ids=position_ids) + query, key, value = qkv.transpose(3, 1).unbind(dim=2) + # query, key, value: [batch_size, heads, seq_len, head_dim] + query, key = apply_rotary_pos_emb(query, key, cos, sin) + + if local_attention != (-1, -1): + attention_mask = sliding_window_mask + + attn_output = ( + F.scaled_dot_product_attention( + query, + key, + value, + dropout_p=module.attention_dropout if module.training else 0.0, + attn_mask=attention_mask, + ) + .transpose(1, 2) + .contiguous() + ) + attn_output = attn_output.view(bs, -1, dim) + return (attn_output,) + + +MODERNBERT_ATTENTION_FUNCTION = { + "flash_attention_2": flash_attention_forward, + "eager": eager_attention_forward, + "sdpa": sdpa_attention_forward, +} + + +class ModernBertAttention(nn.Module): + """Performs multi-headed self attention on a batch of unpadded sequences. + + If Flash Attention 2 is installed, this module uses Flash Attention to improve throughput. + If Flash Attention 2 is not installed, the implementation will use PyTorch's SDPA kernel, + which requires padding and unpadding inputs, adding some overhead. + + See `forward` method for additional details. 
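+
+    Whether a given layer uses global attention or sliding-window (local) attention depends only on its layer
+    index. A small sketch of that schedule, assuming the default configuration values (3 and 128 below are those
+    defaults, not read from a real config):
+
+    ```python
+    global_attn_every_n_layers, local_attention = 3, 128
+
+    for layer_id in range(6):
+        if layer_id % global_attn_every_n_layers == 0:
+            window = (-1, -1)  # full (global) attention
+        else:
+            window = (local_attention // 2, local_attention // 2)  # 64 tokens on each side of the query
+        print(layer_id, window)
+    ```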
+ """ + + def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None): + super().__init__() + self.config = config + self.layer_id = layer_id + + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention heads ({config.num_attention_heads})" + ) + + self.attention_dropout = config.attention_dropout + self.deterministic_flash_attn = config.deterministic_flash_attn + self.num_heads = config.num_attention_heads + self.head_dim = config.hidden_size // config.num_attention_heads + self.all_head_size = self.head_dim * self.num_heads + self.Wqkv = nn.Linear(config.hidden_size, 3 * self.all_head_size, bias=config.attention_bias) + + if layer_id % config.global_attn_every_n_layers != 0: + self.local_attention = (config.local_attention // 2, config.local_attention // 2) + else: + self.local_attention = (-1, -1) + + rope_theta = config.global_rope_theta + max_position_embeddings = config.max_position_embeddings + if self.local_attention != (-1, -1): + if config.local_rope_theta is not None: + rope_theta = config.local_rope_theta + max_position_embeddings = config.local_attention + + if config._attn_implementation == "flash_attention_2": + self.rotary_emb = ModernBertUnpaddedRotaryEmbedding( + dim=self.head_dim, max_seqlen=max_position_embeddings, base=rope_theta + ) + else: + self.rotary_emb = ModernBertRotaryEmbedding( + dim=self.head_dim, max_position_embeddings=max_position_embeddings, base=rope_theta + ) + + self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias) + self.out_drop = nn.Dropout(config.attention_dropout) if config.attention_dropout > 0.0 else nn.Identity() + self.pruned_heads = set() + + def forward( + self, + hidden_states: torch.Tensor, + output_attentions: Optional[bool] = False, + **kwargs, + ) -> torch.Tensor: + qkv = self.Wqkv(hidden_states) + + bs = hidden_states.shape[0] + if self.config._attn_implementation == "flash_attention_2": + qkv = qkv.view(-1, 3, self.num_heads, self.head_dim) + else: + qkv = qkv.view(bs, -1, 3, self.num_heads, self.head_dim) + + attn_outputs = MODERNBERT_ATTENTION_FUNCTION[self.config._attn_implementation]( + self, + qkv=qkv, + rotary_emb=self.rotary_emb, + local_attention=self.local_attention, + bs=bs, + dim=self.all_head_size, + output_attentions=output_attentions, + **kwargs, + ) + hidden_states = attn_outputs[0] + hidden_states = self.out_drop(self.Wo(hidden_states)) + + return (hidden_states,) + attn_outputs[1:] # add attentions if outputted + + +class ModernBertEncoderLayer(nn.Module): + def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None): + super().__init__() + self.config = config + if layer_id == 0: + self.attn_norm = nn.Identity() + else: + self.attn_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) + self.attn = ModernBertAttention(config=config, layer_id=layer_id) + self.mlp_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) + self.mlp = ModernBertMLP(config) + + @torch.compile(dynamic=True) + def compiled_mlp(self, hidden_states: torch.Tensor) -> torch.Tensor: + return self.mlp(self.mlp_norm(hidden_states)) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + sliding_window_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, 
+ output_attentions: Optional[bool] = False, + ) -> torch.Tensor: + attn_outputs = self.attn( + self.attn_norm(hidden_states), + attention_mask=attention_mask, + sliding_window_mask=sliding_window_mask, + position_ids=position_ids, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + output_attentions=output_attentions, + ) + hidden_states = hidden_states + attn_outputs[0] + mlp_output = ( + self.compiled_mlp(hidden_states) + if self.config.reference_compile + else self.mlp(self.mlp_norm(hidden_states)) + ) + hidden_states = hidden_states + mlp_output + + return (hidden_states,) + attn_outputs[1:] # add attentions if outputted + + +MODERNBERT_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`ModernBertConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare ModernBert Model outputting raw hidden-states without any specific head on top.", + MODERNBERT_START_DOCSTRING, +) +class ModernBertPreTrainedModel(PreTrainedModel): + config_class = ModernBertConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["ModernBertEmbeddings", "ModernBertEncoderLayer"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = False + + def _init_weights(self, module: nn.Module): + cutoff_factor = self.config.initializer_cutoff_factor + if cutoff_factor is None: + cutoff_factor = 3 + + def init_weight(module: nn.Module, std: float): + nn.init.trunc_normal_( + module.weight, + mean=0.0, + std=std, + a=-cutoff_factor * std, + b=cutoff_factor * std, + ) + + if isinstance(module, nn.Linear): + if module.bias is not None: + nn.init.zeros_(module.bias) + + stds = { + "in": self.config.initializer_range, + "out": self.config.initializer_range / math.sqrt(2.0 * self.config.num_hidden_layers), + "embedding": self.config.initializer_range, + "final_out": self.config.hidden_size**-0.5, + } + + if isinstance(module, ModernBertEmbeddings): + init_weight(module.tok_embeddings, stds["embedding"]) + elif isinstance(module, ModernBertMLP): + init_weight(module.Wi, stds["in"]) + init_weight(module.Wo, stds["out"]) + elif isinstance(module, ModernBertAttention): + init_weight(module.Wqkv, stds["in"]) + init_weight(module.Wo, stds["out"]) + elif isinstance(module, ModernBertPredictionHead): + init_weight(module.dense, stds["in"]) + elif isinstance(module, ModernBertPoolingHead): + init_weight(module.dense, stds["out"]) + elif isinstance(module, ModernBertForMaskedLM): + init_weight(module.decoder, stds["out"]) + elif isinstance(module, (ModernBertForSequenceClassification, ModernBertForTokenClassification)): + init_weight(module.classifier, stds["final_out"]) + + @classmethod + def _autoset_attn_implementation( + cls, + config, + use_flash_attention_2: bool = False, + torch_dtype: Optional[torch.dtype] = None, + device_map: 
Optional[Union[str, Dict[str, int]]] = None, + check_device_map: bool = True, + ): + # If the user didn't specify anything, try to use flash_attention_2 if available. + # Otherwise we fall back to the default SDPA -> Eager from the super() method. + if config._attn_implementation_internal is None: + config._attn_implementation_internal = "flash_attention_2" + try: + return cls._check_and_enable_flash_attn_2( + config, + torch_dtype=torch_dtype, + device_map=device_map, + hard_check_only=False, + check_device_map=check_device_map, + ) + except (ValueError, ImportError): + config._attn_implementation_internal = None + return super()._autoset_attn_implementation( + config, + use_flash_attention_2=use_flash_attention_2, + torch_dtype=torch_dtype, + device_map=device_map, + check_device_map=check_device_map, + ) + + def _maybe_set_compile(self): + if self.config.reference_compile is False: + return + + if hasattr(self, "hf_device_map") and len(self.hf_device_map) > 1: + if self.config.reference_compile: + logger.warning_once( + "If `accelerate` split the model across devices, `torch.compile` will not work. " + "Falling back to non-compiled mode." + ) + self.config.reference_compile = False + + if self.device.type == "mps": + if self.config.reference_compile: + logger.warning_once( + "Compiling the model with `torch.compile` and using a `torch.mps` device is not supported. " + "Falling back to non-compiled mode." + ) + self.config.reference_compile = False + + if self.config.reference_compile is None: + self.config.reference_compile = is_triton_available() + + def resize_token_embeddings(self, *args, **kwargs): + model_embeds = super().resize_token_embeddings(*args, **kwargs) + + if self.config.reference_compile in {True, None}: + if self.config.reference_compile: + logger.warning_once( + "Resizing token embeddings with `torch.compile` is not supported. Falling back to non-compiled mode." + ) + self.config.reference_compile = False + + return model_embeds + + +MODERNBERT_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding or far-away tokens. 
In ModernBert, only every few layers + perform global attention, while the rest perform local attention. This mask is used to avoid attending to + far-away tokens in the local attention layers. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*): + Indices of the non-padding tokens in the input sequence. Used for unpadding the output. + cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*): + Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors. + max_seqlen (`int`, *optional*): + Maximum sequence length in the batch. Used to pad the output tensors. + batch_size (`int`, *optional*): + Batch size of the input sequences. Used to pad the output tensors. + seq_len (`int`, *optional*): + Sequence length of the input sequences. Used to pad the output tensors. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare ModernBert Model outputting raw hidden-states without any specific head on top.", + MODERNBERT_START_DOCSTRING, +) +class ModernBertModel(ModernBertPreTrainedModel): + def __init__(self, config: ModernBertConfig): + super().__init__(config) + self.config = config + self.embeddings = ModernBertEmbeddings(config) + self.layers = nn.ModuleList( + [ModernBertEncoderLayer(config, layer_id) for layer_id in range(config.num_hidden_layers)] + ) + self.final_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) + self.gradient_checkpointing = False + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.tok_embeddings + + def set_input_embeddings(self, value): + self.embeddings.tok_embeddings = value + + @add_start_docstrings_to_model_forward(MODERNBERT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + sliding_window_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + indices: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + batch_size: Optional[int] = None, + seq_len: Optional[int] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + all_hidden_states = () if output_hidden_states else None + 
all_self_attentions = () if output_attentions else None + + self._maybe_set_compile() + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + + if batch_size is None and seq_len is None: + batch_size, seq_len = input_ids.shape[:2] + + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_len), device=input_ids.device, dtype=torch.bool) + + repad = False + if self.config._attn_implementation == "flash_attention_2": + if indices is None and cu_seqlens is None and max_seqlen is None: + repad = True + with torch.no_grad(): + input_ids, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input( + inputs=input_ids, attention_mask=attention_mask + ) + else: + if position_ids is None: + position_ids = torch.arange(seq_len, device=input_ids.device).unsqueeze(0) + + attention_mask, sliding_window_mask = self._update_attention_mask( + attention_mask, output_attentions=output_attentions + ) + + hidden_states = self.embeddings(input_ids) + + for encoder_layer in self.layers: + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + encoder_layer.__call__, + hidden_states, + attention_mask, + sliding_window_mask, + position_ids, + cu_seqlens, + max_seqlen, + output_attentions, + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask=attention_mask, + sliding_window_mask=sliding_window_mask, + position_ids=position_ids, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + output_attentions=output_attentions, + ) + hidden_states = layer_outputs[0] + if output_attentions and len(layer_outputs) > 1: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + hidden_states = self.final_norm(hidden_states) + + if repad: + hidden_states = _pad_modernbert_output( + inputs=hidden_states, indices=indices, batch=batch_size, seqlen=seq_len + ) + if all_hidden_states is not None: + all_hidden_states = tuple( + _pad_modernbert_output(inputs=hs, indices=indices, batch=batch_size, seqlen=seq_len) + for hs in all_hidden_states + ) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + def _update_attention_mask(self, attention_mask: torch.Tensor, output_attentions: bool) -> torch.Tensor: + if output_attentions: + if self.config._attn_implementation == "sdpa": + logger.warning_once( + "Outputting attentions is only supported with the 'eager' attention implementation, " + 'not with "sdpa". Falling back to `attn_implementation="eager"`.' + ) + self.config._attn_implementation = "eager" + elif self.config._attn_implementation != "eager": + logger.warning_once( + "Outputting attentions is only supported with the eager attention implementation, " + f'not with {self.config._attn_implementation}. Consider setting `attn_implementation="eager"`.' + " Setting `output_attentions=False`." 
+ ) + + global_attention_mask = _prepare_4d_attention_mask(attention_mask, self.dtype) + + # Create position indices + rows = torch.arange(global_attention_mask.shape[2]).unsqueeze(0) + # Calculate distance between positions + distance = torch.abs(rows - rows.T) + + # Create sliding window mask (1 for positions within window, 0 outside) + window_mask = ( + (distance <= self.config.local_attention // 2).unsqueeze(0).unsqueeze(0).to(attention_mask.device) + ) + # Combine with existing mask + sliding_window_mask = global_attention_mask.masked_fill(window_mask.logical_not(), torch.finfo(self.dtype).min) + + return global_attention_mask, sliding_window_mask + + +class ModernBertPredictionHead(nn.Module): + def __init__(self, config: ModernBertConfig): + super().__init__() + self.config = config + self.dense = nn.Linear(config.hidden_size, config.hidden_size, config.classifier_bias) + self.act = ACT2FN[config.classifier_activation] + self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return self.norm(self.act(self.dense(hidden_states))) + + +@add_start_docstrings( + "The ModernBert Model with a decoder head on top that is used for masked language modeling.", + MODERNBERT_START_DOCSTRING, +) +class ModernBertForMaskedLM(ModernBertPreTrainedModel): + _tied_weights_keys = ["decoder.weight"] + + def __init__(self, config: ModernBertConfig): + super().__init__(config) + self.config = config + self.model = ModernBertModel(config) + self.head = ModernBertPredictionHead(config) + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=config.decoder_bias) + + self.sparse_prediction = self.config.sparse_prediction + self.sparse_pred_ignore_index = self.config.sparse_pred_ignore_index + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.decoder + + def set_output_embeddings(self, new_embeddings: nn.Linear): + self.decoder = new_embeddings + + @torch.compile(dynamic=True) + def compiled_head(self, output: torch.Tensor) -> torch.Tensor: + return self.decoder(self.head(output)) + + @add_start_docstrings_to_model_forward(MODERNBERT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=MaskedLMOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + sliding_window_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + indices: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + batch_size: Optional[int] = None, + seq_len: Optional[int] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + self._maybe_set_compile() + + if self.config._attn_implementation == "flash_attention_2": + if indices is None and cu_seqlens is None and max_seqlen is None: + batch_size, seq_len = input_ids.shape[:2] + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_len), device=input_ids.device, dtype=torch.bool) + with torch.no_grad(): + input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = 
_unpad_modernbert_input( + inputs=input_ids, attention_mask=attention_mask, position_ids=position_ids, labels=labels + ) + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + sliding_window_mask=sliding_window_mask, + position_ids=position_ids, + indices=indices, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + batch_size=batch_size, + seq_len=seq_len, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + last_hidden_state = outputs[0] + + if self.sparse_prediction and labels is not None: + # flatten labels and output first + labels = labels.view(-1) + last_hidden_state = last_hidden_state.view(labels.shape[0], -1) + + # then filter out the non-masked tokens + mask_tokens = labels != self.sparse_pred_ignore_index + last_hidden_state = last_hidden_state[mask_tokens] + labels = labels[mask_tokens] + + logits = ( + self.compiled_head(last_hidden_state) + if self.config.reference_compile + else self.decoder(self.head(last_hidden_state)) + ) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, vocab_size=self.config.vocab_size) + + if self.config._attn_implementation == "flash_attention_2": + with torch.no_grad(): + logits = _pad_modernbert_output(inputs=logits, indices=indices, batch=batch_size, seqlen=seq_len) + if not return_dict: + output = (logits,) + return ((loss,) + output) if loss is not None else output + + return MaskedLMOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class ModernBertPoolingHead(nn.Module): + def __init__(self, config: ModernBertConfig): + super().__init__() + self.config = config + self.dense = nn.Linear(config.hidden_size, config.hidden_size, config.classifier_bias) + self.act = ACT2FN[config.classifier_activation] + self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) + self.drop = torch.nn.Dropout(config.classifier_dropout) + + def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: + if self.config.classifier_pooling == "cls": + hidden_states = hidden_states[:, 0] + elif self.config.classifier_pooling == "mean": + hidden_states = (hidden_states * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum( + dim=1, keepdim=True + ) + + return self.drop(self.norm(self.act(self.dense(hidden_states)))) + + +@add_start_docstrings( + "The ModernBert Model with a sequence classification head on top that performs pooling.", + MODERNBERT_START_DOCSTRING, +) +class ModernBertForSequenceClassification(ModernBertPreTrainedModel): + def __init__(self, config: ModernBertConfig): + super().__init__(config) + self.num_labels = config.num_labels + self.config = config + + self.model = ModernBertModel(config) + self.head = ModernBertPoolingHead(config) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(MODERNBERT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=SequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + sliding_window_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + indices: Optional[torch.Tensor] = None, + cu_seqlens: 
Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + batch_size: Optional[int] = None, + seq_len: Optional[int] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + self._maybe_set_compile() + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + sliding_window_mask=sliding_window_mask, + position_ids=position_ids, + indices=indices, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + batch_size=batch_size, + seq_len=seq_len, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + last_hidden_state = outputs[0] + + pooled_output = self.head(last_hidden_state, attention_mask) + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits,) + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + "The ModernBert Model with a token classification head on top, e.g. 
for Named Entity Recognition (NER) tasks.", + MODERNBERT_START_DOCSTRING, +) +class ModernBertForTokenClassification(ModernBertPreTrainedModel): + def __init__(self, config: ModernBertConfig): + super().__init__(config) + self.num_labels = config.num_labels + + self.model = ModernBertModel(config) + self.drop = nn.Dropout(config.classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(MODERNBERT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + sliding_window_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + indices: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + batch_size: Optional[int] = None, + seq_len: Optional[int] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + self._maybe_set_compile() + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + sliding_window_mask=sliding_window_mask, + position_ids=position_ids, + indices=indices, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + batch_size=batch_size, + seq_len=seq_len, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + last_hidden_state = outputs[0] + + last_hidden_state = self.drop(last_hidden_state) + logits = self.classifier(last_hidden_state) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = [ + "ModernBertConfig", + "ModernBertModel", + "ModernBertPreTrainedModel", + "ModernBertForMaskedLM", + "ModernBertForSequenceClassification", + "ModernBertForTokenClassification", +] diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index c9a49d737d092e..e3463461ea07e5 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -6418,6 +6418,41 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class ModernBertForMaskedLM(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class ModernBertForSequenceClassification(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class ModernBertForTokenClassification(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, 
**kwargs): + requires_backends(self, ["torch"]) + + +class ModernBertModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class ModernBertPreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class MoshiForCausalLM(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 32a647594741dd..92823a4ee016c3 100755 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -192,7 +192,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[ _tiktoken_available = _is_package_available("tiktoken") _blobfile_available = _is_package_available("blobfile") _liger_kernel_available = _is_package_available("liger_kernel") - +_triton_available = _is_package_available("triton") _torch_version = "N/A" _torch_available = False @@ -1243,6 +1243,10 @@ def is_liger_kernel_available(): return version.parse(importlib.metadata.version("liger_kernel")) >= version.parse("0.3.0") +def is_triton_available(): + return _triton_available + + # docstyle-ignore AV_IMPORT_ERROR = """ {0} requires the PyAv library but it was not found in your environment. You can install it with: diff --git a/tests/models/modernbert/__init__.py b/tests/models/modernbert/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/tests/models/modernbert/test_modeling_modernbert.py b/tests/models/modernbert/test_modeling_modernbert.py new file mode 100644 index 00000000000000..4fce0cd86352f0 --- /dev/null +++ b/tests/models/modernbert/test_modeling_modernbert.py @@ -0,0 +1,367 @@ +# coding=utf-8 +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import os +import unittest + +import pytest + +from transformers import ModernBertConfig, is_torch_available +from transformers.models.auto import get_values +from transformers.testing_utils import ( + CaptureLogger, + require_flash_attn, + require_torch, + require_torch_gpu, + slow, + torch_device, +) + +from ...generation.test_utils import GenerationTesterMixin +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, _config_zero_init, ids_tensor, random_attention_mask +from ...test_pipeline_mixin import PipelineTesterMixin + + +if is_torch_available(): + import torch + + from transformers import ( + MODEL_FOR_PRETRAINING_MAPPING, + ModernBertForMaskedLM, + ModernBertForSequenceClassification, + ModernBertForTokenClassification, + ModernBertModel, + logging, + ) + + +class ModernBertModelTester: + def __init__( + self, + parent, + batch_size=13, + seq_length=7, + is_training=True, + use_input_mask=True, + use_labels=True, + vocab_size=99, + pad_token_id=0, + hidden_size=32, + num_hidden_layers=2, + num_attention_heads=4, + intermediate_size=37, + hidden_activation="gelu", + mlp_dropout=0.0, + attention_dropout=0.0, + embedding_dropout=0.0, + classifier_dropout=0.0, + max_position_embeddings=512, + type_vocab_size=16, + type_sequence_label_size=2, + initializer_range=0.02, + num_labels=3, + num_choices=4, + scope=None, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.use_input_mask = use_input_mask + self.use_labels = use_labels + self.vocab_size = vocab_size + self.pad_token_id = pad_token_id + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_activation = hidden_activation + self.mlp_dropout = mlp_dropout + self.attention_dropout = attention_dropout + self.embedding_dropout = embedding_dropout + self.classifier_dropout = classifier_dropout + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.type_sequence_label_size = type_sequence_label_size + self.initializer_range = initializer_range + self.num_labels = num_labels + self.num_choices = num_choices + self.scope = scope + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + + input_mask = None + if self.use_input_mask: + input_mask = random_attention_mask([self.batch_size, self.seq_length]) + + sequence_labels = None + token_labels = None + choice_labels = None + if self.use_labels: + sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size) + token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) + choice_labels = ids_tensor([self.batch_size], self.num_choices) + + config = self.get_config() + + return config, input_ids, input_mask, sequence_labels, token_labels, choice_labels + + def get_config(self): + """ + Returns a tiny configuration by default. 
+ """ + config = ModernBertConfig( + vocab_size=self.vocab_size, + pad_token_id=self.pad_token_id, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + intermediate_size=self.intermediate_size, + hidden_activation=self.hidden_activation, + mlp_dropout=self.mlp_dropout, + attention_dropout=self.attention_dropout, + embedding_dropout=self.embedding_dropout, + classifier_dropout=self.classifier_dropout, + max_position_embeddings=self.max_position_embeddings, + type_vocab_size=self.type_vocab_size, + is_decoder=False, + initializer_range=self.initializer_range, + ) + if test := os.environ.get("PYTEST_CURRENT_TEST", False): + test_name = test.split(":")[-1].split(" ")[0] + + # If we're testing `test_retain_grad_hidden_states_attentions`, we normally get an error + # that compilation doesn't work. Users can then set compile=False when loading the model, + # much like here. We're testing whether it works once they've done that. + if test_name == "test_retain_grad_hidden_states_attentions": + config.reference_compile = False + # Some tests require attentions to be outputted, in that case we'll set the attention implementation to eager + # as the others don't support outputted attentions + if test_name in ( + "test_attention_outputs", + "test_hidden_states_output", + "test_retain_grad_hidden_states_attentions", + ): + config._attn_implementation = "eager" + return config + + def create_and_check_model(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels): + model = ModernBertModel(config=config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=input_mask) + result = model(input_ids) + result = model(input_ids) + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + + def create_and_check_for_masked_lm( + self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + model = ModernBertForMaskedLM(config=config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=input_mask, labels=token_labels) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) + + def create_and_check_for_sequence_classification( + self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + config.num_labels = self.num_labels + model = ModernBertForSequenceClassification(config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=input_mask, labels=sequence_labels) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels)) + + def create_and_check_for_token_classification( + self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + config.num_labels = self.num_labels + model = ModernBertForTokenClassification(config=config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=input_mask, labels=token_labels) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels)) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + input_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ) = config_and_inputs + inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask} + return config, inputs_dict + + +@require_torch +class ModernBertModelTest(ModelTesterMixin, 
GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): + test_torchscript = False + + all_model_classes = ( + ( + ModernBertModel, + ModernBertForMaskedLM, + ModernBertForSequenceClassification, + ModernBertForTokenClassification, + ) + if is_torch_available() + else () + ) + all_generative_model_classes = () + pipeline_model_mapping = ( + { + "feature-extraction": ModernBertModel, + "fill-mask": ModernBertForMaskedLM, + "text-classification": ModernBertForSequenceClassification, + "token-classification": ModernBertForTokenClassification, + "zero-shot": ModernBertForSequenceClassification, + } + if is_torch_available() + else {} + ) + fx_compatible = False + test_head_masking = False + test_pruning = False + model_split_percents = [0.5, 0.8, 0.9] + + # special case for ForPreTraining model + def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): + inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels) + + if inputs_dict.get("output_attentions", False): + inputs_dict["output_attentions"] = True + + if return_labels: + if model_class in get_values(MODEL_FOR_PRETRAINING_MAPPING): + inputs_dict["labels"] = torch.zeros( + (self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device + ) + inputs_dict["next_sentence_label"] = torch.zeros( + self.model_tester.batch_size, dtype=torch.long, device=torch_device + ) + return inputs_dict + + def setUp(self): + self.model_tester = ModernBertModelTester(self) + self.config_tester = ConfigTester(self, config_class=ModernBertConfig, hidden_size=37) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_model_various_embeddings(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + for type in ["absolute", "relative_key", "relative_key_query"]: + config_and_inputs[0].position_embedding_type = type + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_initialization(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + configs_no_init = _config_zero_init(config) + for model_class in self.all_model_classes: + model = model_class(config=configs_no_init) + for name, param in model.named_parameters(): + # The classifier.weight from ModernBertForSequenceClassification and ModernBertForTokenClassification + # are initialized without `initializer_range`, so they're not set to ~0 via the _config_zero_init + if param.requires_grad and not ( + name == "classifier.weight" + and model_class in [ModernBertForSequenceClassification, ModernBertForTokenClassification] + ): + self.assertIn( + ((param.data.mean() * 1e9).round() / 1e9).item(), + [0.0, 1.0], + msg=f"Parameter {name} of model {model_class} seems not properly initialized", + ) + + @unittest.skip("ModernBert doesn't use `inputs_embeds` as input.") + def test_inputs_embeds(self): + pass + + def test_for_masked_lm(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_masked_lm(*config_and_inputs) + + def test_for_sequence_classification(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs) + + def test_for_token_classification(self): + config_and_inputs = 
self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_token_classification(*config_and_inputs) + + def test_for_warning_if_padding_and_no_attention_mask(self): + ( + config, + input_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ) = self.model_tester.prepare_config_and_inputs() + + # Set pad tokens in the input_ids + input_ids[0, 0] = config.pad_token_id + + # Check for warnings if the attention_mask is missing. + logger = logging.get_logger("transformers.modeling_utils") + # clear cache so we can test the warning is emitted (from `warning_once`). + logger.warning_once.cache_clear() + + with CaptureLogger(logger) as cl: + model = ModernBertModel(config=config) + model.to(torch_device) + model.eval() + model(input_ids, attention_mask=None) + self.assertIn("We strongly recommend passing in an `attention_mask`", cl.out) + + @unittest.skip("ModernBert doesn't use separate classes for SDPA, but a function instead.") + def test_sdpa_can_dispatch_non_composite_models(self): + pass + + @slow + def test_model_from_pretrained(self): + model_name = "google-bert/bert-base-uncased" + model = ModernBertModel.from_pretrained(model_name) + self.assertIsNotNone(model) + + @require_flash_attn + @require_torch_gpu + @pytest.mark.flash_attn_test + @slow + def test_flash_attn_2_inference_equivalence_right_padding(self): + self.skipTest(reason="ModernBert flash attention does not support right padding") + + @require_flash_attn + @require_torch_gpu + @pytest.mark.flash_attn_test + @slow + def test_flash_attn_2_conversion(self): + self.skipTest(reason="ModernBert doesn't use the ModernBertFlashAttention2 class method.") + + +@require_torch +class ModernBertModelIntegrationTest(unittest.TestCase): + """ + These still need to be written, once public models are available. 
+ """ diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 1d7e995f80756c..5f053c20ff7938 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -3457,6 +3457,8 @@ def test_mismatched_shapes_have_properly_initialized_weights(self): "Data2VecAudioForSequenceClassification", "UniSpeechForSequenceClassification", "PvtForImageClassification", + "ModernBertForSequenceClassification", + "ModernBertForTokenClassification", "TimmWrapperForImageClassification", ] special_param_names = [ @@ -4042,7 +4044,12 @@ def get_mean_reldiff(failcase, x, ref, atol, rtol): with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) - model_sdpa = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype) + try: + model_sdpa = model_class.from_pretrained( + tmpdirname, torch_dtype=torch_dtype, attn_implementation="sdpa" + ) + except ValueError: + model_sdpa = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype) model_sdpa = model_sdpa.eval().to(torch_device, dtype=torch_dtype) model_eager = model_class.from_pretrained( From 1fa807fa63d1aa9409fb1ae0cbb7583960e5ea98 Mon Sep 17 00:00:00 2001 From: Arthur <48595927+ArthurZucker@users.noreply.github.com> Date: Thu, 19 Dec 2024 17:05:25 +0100 Subject: [PATCH 09/23] Fix some fa2 tests (#35340) * remove fa2 test * remove other failing tests * style --- tests/models/granite/test_modeling_granite.py | 29 ----------------- .../granitemoe/test_modeling_granitemoe.py | 29 ----------------- tests/models/llama/test_modeling_llama.py | 31 ------------------- tests/test_modeling_common.py | 30 ------------------ 4 files changed, 119 deletions(-) diff --git a/tests/models/granite/test_modeling_granite.py b/tests/models/granite/test_modeling_granite.py index 60eb964927278a..686544825c3551 100644 --- a/tests/models/granite/test_modeling_granite.py +++ b/tests/models/granite/test_modeling_granite.py @@ -14,14 +14,12 @@ # limitations under the License. """Testing suite for the PyTorch Granite model.""" -import tempfile import unittest from parameterized import parameterized from transformers import GraniteConfig, is_torch_available, set_seed from transformers.testing_utils import ( - require_flash_attn, require_read_token, require_torch, require_torch_gpu, @@ -417,33 +415,6 @@ def test_model_rope_scaling(self): with self.assertRaises(AssertionError): torch.testing.assert_close(yarn_sin_long, original_sin_long) - @require_flash_attn - @require_torch_gpu - @slow - def test_use_flash_attention_2_true(self): - """ - NOTE: this is the only test testing that the legacy `use_flash_attention=2` argument still works as intended. 
- """ - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - for model_class in self.all_model_classes: - with tempfile.TemporaryDirectory() as tmp_dir: - model = model_class(config) - model.save_pretrained(tmp_dir) - - new_model = GraniteForCausalLM.from_pretrained( - tmp_dir, use_flash_attention_2=True, torch_dtype=torch.float16 - ).to("cuda") - - self.assertTrue(new_model.config._attn_implementation == "flash_attention_2") - - has_flash = False - for name, submodule in new_model.named_modules(): - if "FlashAttention" in submodule.__class__.__name__: - has_flash = True - break - if not has_flash: - raise ValueError("The flash model should have flash attention layers") - @require_torch_gpu class GraniteIntegrationTest(unittest.TestCase): diff --git a/tests/models/granitemoe/test_modeling_granitemoe.py b/tests/models/granitemoe/test_modeling_granitemoe.py index 97af65667ed048..31307865a77da7 100644 --- a/tests/models/granitemoe/test_modeling_granitemoe.py +++ b/tests/models/granitemoe/test_modeling_granitemoe.py @@ -14,14 +14,12 @@ # limitations under the License. """Testing suite for the PyTorch GraniteMoe model.""" -import tempfile import unittest from parameterized import parameterized from transformers import AutoTokenizer, GraniteMoeConfig, is_torch_available, set_seed from transformers.testing_utils import ( - require_flash_attn, require_read_token, require_torch, require_torch_gpu, @@ -416,33 +414,6 @@ def test_model_rope_scaling(self): with self.assertRaises(AssertionError): torch.testing.assert_close(yarn_sin_long, original_sin_long) - @require_flash_attn - @require_torch_gpu - @slow - def test_use_flash_attention_2_true(self): - """ - NOTE: this is the only test testing that the legacy `use_flash_attention=2` argument still works as intended. - """ - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - for model_class in self.all_model_classes: - with tempfile.TemporaryDirectory() as tmp_dir: - model = model_class(config) - model.save_pretrained(tmp_dir) - - new_model = GraniteMoeForCausalLM.from_pretrained( - tmp_dir, use_flash_attention_2=True, torch_dtype=torch.float16 - ).to("cuda") - - self.assertTrue(new_model.config._attn_implementation == "flash_attention_2") - - has_flash = False - for name, submodule in new_model.named_modules(): - if "FlashAttention" in submodule.__class__.__name__: - has_flash = True - break - if not has_flash: - raise ValueError("The flash model should have flash attention layers") - @require_torch_gpu class GraniteMoeIntegrationTest(unittest.TestCase): diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py index 78e42e6ba71f2f..feca640bb4a119 100644 --- a/tests/models/llama/test_modeling_llama.py +++ b/tests/models/llama/test_modeling_llama.py @@ -14,10 +14,8 @@ # limitations under the License. 
"""Testing suite for the PyTorch LLaMA model.""" -import tempfile import unittest -import pytest from packaging import version from parameterized import parameterized @@ -25,7 +23,6 @@ from transformers.generation.configuration_utils import GenerationConfig from transformers.testing_utils import ( cleanup, - require_flash_attn, require_read_token, require_torch, require_torch_accelerator, @@ -543,34 +540,6 @@ def _reinitialize_config(base_config, new_kwargs): with self.assertRaises(KeyError): config = _reinitialize_config(base_config, {"rope_scaling": {"rope_type": "linear"}}) # missing "factor" - @require_flash_attn - @require_torch_gpu - @slow - @pytest.mark.flash_attn_test - def test_use_flash_attention_2_true(self): - """ - NOTE: this is the only test testing that the legacy `use_flash_attention=2` argument still works as intended. - """ - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - for model_class in self.all_model_classes: - with tempfile.TemporaryDirectory() as tmp_dir: - model = model_class(config) - model.save_pretrained(tmp_dir) - - new_model = LlamaForCausalLM.from_pretrained( - tmp_dir, use_flash_attention_2=True, torch_dtype=torch.float16 - ).to("cuda") - - self.assertTrue(new_model.config._attn_implementation == "flash_attention_2") - - has_flash = False - for name, submodule in new_model.named_modules(): - if "FlashAttention" in submodule.__class__.__name__: - has_flash = True - break - if not has_flash: - raise ValueError("The flash model should have flash attention layers") - @require_torch_gpu class LlamaIntegrationTest(unittest.TestCase): diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 5f053c20ff7938..f150477c6231f4 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -2769,8 +2769,6 @@ def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=1e-4, n attributes = tuple([f"{name}_{idx}" for idx in range(len(fx_outputs))]) for fx_output, pt_output, attr in zip(fx_outputs, pt_outputs, attributes): - if isinstance(pt_output, DynamicCache): - pt_output = pt_output.to_legacy_cache() self.check_pt_flax_outputs(fx_output, pt_output, model_class, tol=tol, name=attr) elif isinstance(fx_outputs, jnp.ndarray): @@ -3612,34 +3610,6 @@ def test_model_is_small(self): num_params < 1000000 ), f"{model_class} is too big for the common tests ({num_params})! It should have 1M max." 
- @require_flash_attn - @require_torch_gpu - @mark.flash_attn_test - @slow - def test_flash_attn_2_conversion(self): - if not self.has_attentions: - self.skipTest(reason="Model architecture does not support attentions") - - config, _ = self.model_tester.prepare_config_and_inputs_for_common() - - for model_class in self.all_model_classes: - if not model_class._supports_flash_attn_2: - self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") - - model = model_class(config) - - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_pretrained(tmpdirname) - model = model_class.from_pretrained( - tmpdirname, torch_dtype=torch.float16, attn_implementation="flash_attention_2" - ).to(torch_device) - - for _, module in model.named_modules(): - if "FlashAttention" in module.__class__.__name__: - return - - self.assertTrue(False, "FlashAttention2 modules not found in model") - @require_flash_attn @require_torch_gpu @mark.flash_attn_test From 0ade1caa356dce6b70ef8293addeb0898f177206 Mon Sep 17 00:00:00 2001 From: Benjamin Warner Date: Thu, 19 Dec 2024 11:22:37 -0500 Subject: [PATCH 10/23] Modernbert Release Fixes (#35344) * fix ForSequenceClassification * unmodularize rope layer * fix linting warning * Avoid complex PoolingHead, only one prediction head needed --------- Co-authored-by: Tom Aarsen --- .../models/modernbert/modeling_modernbert.py | 39 ++++------- .../models/modernbert/modular_modernbert.py | 69 +++++++++++-------- 2 files changed, 55 insertions(+), 53 deletions(-) diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py index db8d98893f96fe..237fba6f645fa5 100644 --- a/src/transformers/models/modernbert/modeling_modernbert.py +++ b/src/transformers/models/modernbert/modeling_modernbert.py @@ -610,8 +610,6 @@ def init_weight(module: nn.Module, std: float): init_weight(module.Wqkv, stds["in"]) init_weight(module.Wo, stds["out"]) elif isinstance(module, ModernBertPredictionHead): - init_weight(module.dense, stds["in"]) - elif isinstance(module, ModernBertPoolingHead): init_weight(module.dense, stds["out"]) elif isinstance(module, ModernBertForMaskedLM): init_weight(module.decoder, stds["out"]) @@ -1109,26 +1107,6 @@ def forward( ) -class ModernBertPoolingHead(nn.Module): - def __init__(self, config: ModernBertConfig): - super().__init__() - self.config = config - self.dense = nn.Linear(config.hidden_size, config.hidden_size, config.classifier_bias) - self.act = ACT2FN[config.classifier_activation] - self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) - self.drop = torch.nn.Dropout(config.classifier_dropout) - - def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: - if self.config.classifier_pooling == "cls": - hidden_states = hidden_states[:, 0] - elif self.config.classifier_pooling == "mean": - hidden_states = (hidden_states * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum( - dim=1, keepdim=True - ) - - return self.drop(self.norm(self.act(self.dense(hidden_states)))) - - @add_start_docstrings( "The ModernBert Model with a sequence classification head on top that performs pooling.", MODERNBERT_START_DOCSTRING, @@ -1140,7 +1118,8 @@ def __init__(self, config: ModernBertConfig): self.config = config self.model = ModernBertModel(config) - self.head = ModernBertPoolingHead(config) + self.head = ModernBertPredictionHead(config) + self.drop = torch.nn.Dropout(config.classifier_dropout) self.classifier = 
nn.Linear(config.hidden_size, config.num_labels) # Initialize weights and apply final processing @@ -1194,7 +1173,15 @@ def forward( ) last_hidden_state = outputs[0] - pooled_output = self.head(last_hidden_state, attention_mask) + if self.config.classifier_pooling == "cls": + last_hidden_state = last_hidden_state[:, 0] + elif self.config.classifier_pooling == "mean": + last_hidden_state = (last_hidden_state * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum( + dim=1, keepdim=True + ) + + pooled_output = self.head(last_hidden_state) + pooled_output = self.drop(pooled_output) logits = self.classifier(pooled_output) loss = None @@ -1242,7 +1229,8 @@ def __init__(self, config: ModernBertConfig): self.num_labels = config.num_labels self.model = ModernBertModel(config) - self.drop = nn.Dropout(config.classifier_dropout) + self.head = ModernBertPredictionHead(config) + self.drop = torch.nn.Dropout(config.classifier_dropout) self.classifier = nn.Linear(config.hidden_size, config.num_labels) # Initialize weights and apply final processing @@ -1293,6 +1281,7 @@ def forward( ) last_hidden_state = outputs[0] + last_hidden_state = self.head(last_hidden_state) last_hidden_state = self.drop(last_hidden_state) logits = self.classifier(last_hidden_state) diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py index 3c23f9178b1b51..dac356146f3015 100644 --- a/src/transformers/models/modernbert/modular_modernbert.py +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -40,7 +40,7 @@ logging, ) from ...utils.import_utils import is_triton_available -from ..gemma.modeling_gemma import GemmaRotaryEmbedding, apply_rotary_pos_emb +from ..gemma.modeling_gemma import apply_rotary_pos_emb if is_flash_attn_2_available(): @@ -493,8 +493,32 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return self.Wo(self.drop(self.act(input) * gate)) -class ModernBertRotaryEmbedding(GemmaRotaryEmbedding): - pass +class ModernBertRotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim)) + self.register_buffer("inv_freq", tensor=inv_freq, persistent=False) + + @torch.no_grad() + def forward(self, x, position_ids, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + self.inv_freq.to(x.device) + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) def eager_attention_forward( @@ -811,8 +835,6 @@ def init_weight(module: nn.Module, std: float): init_weight(module.Wqkv, stds["in"]) init_weight(module.Wo, stds["out"]) elif isinstance(module, ModernBertPredictionHead): - init_weight(module.dense, stds["in"]) - elif 
isinstance(module, ModernBertPoolingHead): init_weight(module.dense, stds["out"]) elif isinstance(module, ModernBertForMaskedLM): init_weight(module.decoder, stds["out"]) @@ -1238,26 +1260,6 @@ def forward( ) -class ModernBertPoolingHead(nn.Module): - def __init__(self, config: ModernBertConfig): - super().__init__() - self.config = config - self.dense = nn.Linear(config.hidden_size, config.hidden_size, config.classifier_bias) - self.act = ACT2FN[config.classifier_activation] - self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) - self.drop = torch.nn.Dropout(config.classifier_dropout) - - def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: - if self.config.classifier_pooling == "cls": - hidden_states = hidden_states[:, 0] - elif self.config.classifier_pooling == "mean": - hidden_states = (hidden_states * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum( - dim=1, keepdim=True - ) - - return self.drop(self.norm(self.act(self.dense(hidden_states)))) - - @add_start_docstrings( "The ModernBert Model with a sequence classification head on top that performs pooling.", MODERNBERT_START_DOCSTRING, @@ -1269,7 +1271,8 @@ def __init__(self, config: ModernBertConfig): self.config = config self.model = ModernBertModel(config) - self.head = ModernBertPoolingHead(config) + self.head = ModernBertPredictionHead(config) + self.drop = torch.nn.Dropout(config.classifier_dropout) self.classifier = nn.Linear(config.hidden_size, config.num_labels) # Initialize weights and apply final processing @@ -1323,7 +1326,15 @@ def forward( ) last_hidden_state = outputs[0] - pooled_output = self.head(last_hidden_state, attention_mask) + if self.config.classifier_pooling == "cls": + last_hidden_state = last_hidden_state[:, 0] + elif self.config.classifier_pooling == "mean": + last_hidden_state = (last_hidden_state * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum( + dim=1, keepdim=True + ) + + pooled_output = self.head(last_hidden_state) + pooled_output = self.drop(pooled_output) logits = self.classifier(pooled_output) loss = None @@ -1371,7 +1382,8 @@ def __init__(self, config: ModernBertConfig): self.num_labels = config.num_labels self.model = ModernBertModel(config) - self.drop = nn.Dropout(config.classifier_dropout) + self.head = ModernBertPredictionHead(config) + self.drop = torch.nn.Dropout(config.classifier_dropout) self.classifier = nn.Linear(config.hidden_size, config.num_labels) # Initialize weights and apply final processing @@ -1422,6 +1434,7 @@ def forward( ) last_hidden_state = outputs[0] + last_hidden_state = self.head(last_hidden_state) last_hidden_state = self.drop(last_hidden_state) logits = self.classifier(last_hidden_state) From f42084e6411c39b74309af4a7d6ed640c01a4c9e Mon Sep 17 00:00:00 2001 From: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com> Date: Thu, 19 Dec 2024 23:45:52 +0100 Subject: [PATCH 11/23] [`docs`] Add link to ModernBERT Text Classification GLUE finetuning script (#35347) Add link to ModernBERT Text Classification GLUE finetuning script --- docs/source/en/model_doc/modernbert.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/source/en/model_doc/modernbert.md b/docs/source/en/model_doc/modernbert.md index ab09f38ff12154..b641d7f3f58199 100644 --- a/docs/source/en/model_doc/modernbert.md +++ b/docs/source/en/model_doc/modernbert.md @@ -50,6 +50,10 @@ The original code can be found [here](https://github.com/answerdotai/modernbert) A list of official Hugging Face and 
community (indicated by 🌎) resources to help you get started with ModernBert. + + +- A notebook on how to [finetune for General Language Understanding Evaluation (GLUE) with Transformers](https://github.com/AnswerDotAI/ModernBERT/blob/main/examples/finetune_modernbert_on_glue.ipynb), also available as a Google Colab [notebook](https://colab.research.google.com/github/AnswerDotAI/ModernBERT/blob/main/examples/finetune_modernbert_on_glue.ipynb). 🌎 + - A script on how to [finetune for text similarity or information retrieval with Sentence Transformers](https://github.com/AnswerDotAI/ModernBERT/blob/main/examples/train_st.py). 🌎 From ff9141bb85f22e7b200f0fbed76fd3641990ed7b Mon Sep 17 00:00:00 2001 From: Nikos Antoniou <57691096+nikosanto13@users.noreply.github.com> Date: Fri, 20 Dec 2024 10:22:05 +0200 Subject: [PATCH 12/23] fix onnx export of speech foundation models (#34224) * added expanded attention/padding masks prior to indexing the hidden_states * consistency fix in WavLMForSequenceClassification --------- Co-authored-by: Nikos Antoniou --- .../models/data2vec/modeling_data2vec_audio.py | 3 ++- src/transformers/models/hubert/modeling_hubert.py | 3 ++- src/transformers/models/sew/modeling_sew.py | 9 +++++---- src/transformers/models/sew_d/modeling_sew_d.py | 6 ++++-- src/transformers/models/unispeech/modeling_unispeech.py | 3 ++- .../models/unispeech_sat/modeling_unispeech_sat.py | 3 ++- src/transformers/models/wav2vec2/modeling_wav2vec2.py | 3 ++- .../models/wav2vec2_bert/modeling_wav2vec2_bert.py | 3 ++- .../wav2vec2_conformer/modeling_wav2vec2_conformer.py | 6 ++++-- src/transformers/models/wavlm/modeling_wavlm.py | 9 ++++++--- 10 files changed, 31 insertions(+), 17 deletions(-) diff --git a/src/transformers/models/data2vec/modeling_data2vec_audio.py b/src/transformers/models/data2vec/modeling_data2vec_audio.py index 03102d22ca0d77..801bd19fca3b60 100755 --- a/src/transformers/models/data2vec/modeling_data2vec_audio.py +++ b/src/transformers/models/data2vec/modeling_data2vec_audio.py @@ -1421,7 +1421,8 @@ def forward( pooled_output = hidden_states.mean(dim=1) else: padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask) - hidden_states[~padding_mask] = 0.0 + expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) + hidden_states[~expand_padding_mask] = 0.0 pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1) logits = self.classifier(pooled_output) diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py index 1629f7d4f3feae..f2700836789ebd 100755 --- a/src/transformers/models/hubert/modeling_hubert.py +++ b/src/transformers/models/hubert/modeling_hubert.py @@ -1629,7 +1629,8 @@ def forward( pooled_output = hidden_states.mean(dim=1) else: padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask) - hidden_states[~padding_mask] = 0.0 + expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) + hidden_states[~expand_padding_mask] = 0.0 pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1) logits = self.classifier(pooled_output) diff --git a/src/transformers/models/sew/modeling_sew.py b/src/transformers/models/sew/modeling_sew.py index 1959d21e1d5d94..8dc3e2297d4525 100644 --- a/src/transformers/models/sew/modeling_sew.py +++ b/src/transformers/models/sew/modeling_sew.py @@ -882,15 +882,15 @@ def forward( all_self_attentions = () if 
output_attentions else None if attention_mask is not None: + expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) if self._use_flash_attention_2: # make sure padded tokens output 0 - hidden_states[~attention_mask] = 0.0 + hidden_states[~expand_attention_mask] = 0.0 # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None else: # make sure padded tokens output 0 - hidden_states[~attention_mask] = 0.0 - + hidden_states[~expand_attention_mask] = 0.0 input_lengths = (attention_mask.long()).sum(-1) # apply pooling formula to get real output_lengths output_lengths = input_lengths // self.config.squeeze_factor @@ -1473,7 +1473,8 @@ def forward( pooled_output = hidden_states.mean(dim=1) else: padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask) - hidden_states[~padding_mask] = 0.0 + expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) + hidden_states[~expand_padding_mask] = 0.0 pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1) logits = self.classifier(pooled_output) diff --git a/src/transformers/models/sew_d/modeling_sew_d.py b/src/transformers/models/sew_d/modeling_sew_d.py index 5cccc0218e6ccf..2df687f4cc362a 100644 --- a/src/transformers/models/sew_d/modeling_sew_d.py +++ b/src/transformers/models/sew_d/modeling_sew_d.py @@ -1175,7 +1175,8 @@ def forward( ) else: # make sure padded tokens output 0 - hidden_states[~attention_mask.bool()] = 0.0 + expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) + hidden_states[~expand_attention_mask.bool()] = 0.0 input_lengths = (attention_mask.long()).sum(-1) # apply pooling formula to get real output_lengths @@ -1721,7 +1722,8 @@ def forward( pooled_output = hidden_states.mean(dim=1) else: padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask) - hidden_states[~padding_mask] = 0.0 + expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) + hidden_states[~expand_padding_mask] = 0.0 pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1) logits = self.classifier(pooled_output) diff --git a/src/transformers/models/unispeech/modeling_unispeech.py b/src/transformers/models/unispeech/modeling_unispeech.py index d1496432279527..f355eb03bdb82f 100755 --- a/src/transformers/models/unispeech/modeling_unispeech.py +++ b/src/transformers/models/unispeech/modeling_unispeech.py @@ -1876,7 +1876,8 @@ def forward( pooled_output = hidden_states.mean(dim=1) else: padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask) - hidden_states[~padding_mask] = 0.0 + expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) + hidden_states[~expand_padding_mask] = 0.0 pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1) logits = self.classifier(pooled_output) diff --git a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py index 49551b73577ad7..0fd6e7cb2c04e1 100755 --- a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py +++ b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py @@ -1886,7 +1886,8 @@ def forward( pooled_output = hidden_states.mean(dim=1) else: padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], 
attention_mask) - hidden_states[~padding_mask] = 0.0 + expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) + hidden_states[~expand_padding_mask] = 0.0 pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1) logits = self.classifier(pooled_output) diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index ca743e1eaef3af..e4df2e6ae3b718 100755 --- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -2376,7 +2376,8 @@ def forward( pooled_output = hidden_states.mean(dim=1) else: padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask) - hidden_states[~padding_mask] = 0.0 + expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) + hidden_states[~expand_padding_mask] = 0.0 pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1) logits = self.classifier(pooled_output) diff --git a/src/transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py b/src/transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py index 6f1d5576df7316..7774c7a4069d02 100644 --- a/src/transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +++ b/src/transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py @@ -1359,7 +1359,8 @@ def forward( pooled_output = hidden_states.mean(dim=1) else: padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask) - hidden_states[~padding_mask] = 0.0 + expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) + hidden_states[~expand_padding_mask] = 0.0 pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1) logits = self.classifier(pooled_output) diff --git a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py index 933bf8f6dc0bcd..494654a6774754 100644 --- a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +++ b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py @@ -878,7 +878,8 @@ def forward( if attention_mask is not None: # make sure padded tokens output 0 - hidden_states[~attention_mask] = 0.0 + expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) + hidden_states[~expand_attention_mask] = 0.0 # extend attention_mask attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype) @@ -1791,7 +1792,8 @@ def forward( pooled_output = hidden_states.mean(dim=1) else: padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask) - hidden_states[~padding_mask] = 0.0 + expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) + hidden_states[~expand_padding_mask] = 0.0 pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1) logits = self.classifier(pooled_output) diff --git a/src/transformers/models/wavlm/modeling_wavlm.py b/src/transformers/models/wavlm/modeling_wavlm.py index 4df192fda5efa3..3e5e3790005377 100755 --- a/src/transformers/models/wavlm/modeling_wavlm.py +++ b/src/transformers/models/wavlm/modeling_wavlm.py @@ -691,7 +691,8 @@ def forward( if attention_mask is not None: # make sure padded tokens output 0 - hidden_states[~attention_mask] = 0.0 + expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, 
hidden_states.shape[2]) + hidden_states[~expand_attention_mask] = 0 position_embeddings = self.pos_conv_embed(hidden_states) hidden_states = hidden_states + position_embeddings @@ -776,7 +777,8 @@ def forward( if attention_mask is not None: # make sure padded tokens are not attended to - hidden_states[~attention_mask] = 0 + expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) + hidden_states[~expand_attention_mask] = 0 position_embeddings = self.pos_conv_embed(hidden_states) hidden_states = hidden_states + position_embeddings @@ -1508,7 +1510,8 @@ def forward( pooled_output = hidden_states.mean(dim=1) else: padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask) - hidden_states[~padding_mask] = 0.0 + expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) + hidden_states[~expand_padding_mask] = 0.0 pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1) logits = self.classifier(pooled_output) From 5a2aedca1e9b1d7cc7c6ce3e65034c6df7863a95 Mon Sep 17 00:00:00 2001 From: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com> Date: Fri, 20 Dec 2024 03:27:47 -0500 Subject: [PATCH 13/23] [`Mamba2`] Fix caching, slow path, and multi-gpu (#35154) * fixup mamba2 - caching and several other small fixes * fixup cached forward * correct fix this time * fixup cache - we do not need to extend the attn mask it's handled by generate (gives total ids + mask at each step) * remove unnecessary (un)squeeze * fixup cache position * simplify a few things * [run-slow] mamba2 * multi gpu attempt two * [run-slow] mamba2 * [run-slow] mamba2 * [run-slow] mamba2 * [run-slow] mamba2 * add newer slow path fix * [run-slow] mamba2 --- .../models/mamba2/modeling_mamba2.py | 362 ++++++++++-------- tests/models/mamba2/test_modeling_mamba2.py | 82 +++- 2 files changed, 265 insertions(+), 179 deletions(-) diff --git a/src/transformers/models/mamba2/modeling_mamba2.py b/src/transformers/models/mamba2/modeling_mamba2.py index c312b9b94351d2..550eeb7f9665e4 100644 --- a/src/transformers/models/mamba2/modeling_mamba2.py +++ b/src/transformers/models/mamba2/modeling_mamba2.py @@ -44,14 +44,22 @@ from mamba_ssm.ops.triton.selective_state_update import selective_state_update from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined else: - selective_state_update = None + mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined, selective_state_update = None, None, None if is_causal_conv1d_available(): from causal_conv1d import causal_conv1d_fn, causal_conv1d_update else: causal_conv1d_update, causal_conv1d_fn = None, None -is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update)) +is_fast_path_available = all( + ( + selective_state_update, + mamba_chunk_scan_combined, + mamba_split_conv1d_scan_combined, + causal_conv1d_fn, + causal_conv1d_update, + ) +) _CHECKPOINT_FOR_DOC = "mistralai/mamba-codestral-7B-v0.1" _CONFIG_FOR_DOC = "Mamba2Config" @@ -111,6 +119,17 @@ def segment_sum(input_tensor): return tensor_segsum +def apply_mask_to_padding_states(hidden_states, attention_mask): + """ + Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66 + """ + if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: + dtype = hidden_states.dtype + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) + + return hidden_states + + 
class Mamba2Cache: """ Arguments: @@ -120,51 +139,69 @@ class Mamba2Cache: device: torch.device Attributes: - seqlen_offset: int - dtype: torch.dtype - conv_states: Dict[int, torch.Tensor] # layer_idx -> [batch_size, intermediate_size, conv_kernel_size] - ssm_states: Dict[int, torch.Tensor] # layer_idx -> [batch_size, intermediate_size, ssm_state_size] + dtype: (`torch.dtype`): + The default `dtype` used to initializing the cache. + conv_kernel_size: (`int`): + Model's convolution kernel size taken from config. + n_groups: (`int`): + Model's number of groups taken from the config - similar to tensor parallel in Transformer. + state_size: (`int`): + Model's SSM state size taken from config. + num_heads: (`int`): + The number of heads used in the linear attention / SSM. + head_dim: (`int`): + The respective dimension of the heads used in the linear attention / SSM. + intermediate_size: (`int`): + Model's intermediate_size based on (expand * hidden_dim) from config. + conv_states: (`torch.Tensor`): + A tensor of shape `[num_layers, batch_size, conv_kernel_size, intermediate_size + 2 * n_groups * state_size]` that holds convolutional states. + ssm_states: (`torch.Tensor`): + A tensor of shape `[num_layers, batch_size, num_heads, head_dim, state_size]` that holds ssm states. """ def __init__( self, config: Mamba2Config, batch_size: int, dtype: torch.dtype = torch.float16, device: Optional[str] = None ): - self.seqlen_offset = 0 self.dtype = dtype self.conv_kernel_size = config.conv_kernel + self.n_groups = config.n_groups + self.state_size = config.state_size + self.num_heads = config.num_heads + self.head_dim = config.head_dim self.intermediate_size = int(config.expand * config.hidden_size) - self.conv_states = { - i: torch.zeros( - batch_size, - self.intermediate_size + 2 * config.n_groups * config.state_size, - self.conv_kernel_size, - device=device, - dtype=dtype, - ) - for i in range(config.num_hidden_layers) - } - self.ssm_states = { - i: torch.zeros( - batch_size, config.num_heads, config.head_dim, config.state_size, device=device, dtype=dtype - ) - for i in range(config.num_hidden_layers) - } - self.activation = config.hidden_act - self.act = ACT2FN[config.hidden_act] + self.conv_states = torch.zeros( + config.num_hidden_layers, + batch_size, + self.intermediate_size + 2 * self.n_groups * self.state_size, + self.conv_kernel_size, + device=device, + dtype=dtype, + ) + self.ssm_states = torch.zeros( + config.num_hidden_layers, + batch_size, + self.num_heads, + self.head_dim, + self.state_size, + device=device, + dtype=dtype, + ) def update_conv_state( - self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor + self, layer_idx: int, new_conv_state: torch.Tensor, cache_init: bool = False ) -> torch.Tensor: - conv_state = self.conv_states[layer_idx] - cache_position = cache_position.clamp(0, self.conv_kernel_size - 1) - - conv_state = conv_state.roll(shifts=-1, dims=-1) - conv_state[:, :, cache_position] = new_conv_state.to(conv_state.device) - self.conv_states[layer_idx].zero_() - self.conv_states[layer_idx] += conv_state + if cache_init: + self.conv_states[layer_idx] = new_conv_state.to(self.conv_states.device) + else: + self.conv_states[layer_idx] = self.conv_states[layer_idx].roll(shifts=-1, dims=-1) + self.conv_states[layer_idx][:, :, -1] = new_conv_state[:, 0, :].to(self.conv_states.device) return self.conv_states[layer_idx] + def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor): + self.ssm_states[layer_idx] = 
new_ssm_state.to(self.ssm_states.device) + return self.ssm_states[layer_idx] + def reset(self): self.conv_states.zero_() self.ssm_states.zero_() @@ -269,19 +306,27 @@ def cuda_kernels_forward( cache_position: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, ): - # set up dimensions for reshapes later + # 1. Gated MLP's linear projection + hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask) + projected_states = self.in_proj(hidden_states) + # Set up dimensions for reshapes later batch_size, seq_len, _ = hidden_states.shape groups_time_state_size = self.n_groups * self.ssm_state_size - d_to_remove = 2 * self.intermediate_size + 2 * self.n_groups * self.ssm_state_size + self.num_heads - - # getting projected states from cache if it exists - if cache_params is not None and cache_params.seqlen_offset > 0: - in_projected_states = self.in_proj(hidden_states.squeeze(1)) # (B 2D) - d_mlp = (in_projected_states.shape[-1] - d_to_remove) // 2 - split_projection_dim = [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads] - _, _, gate, hidden_states_B_C, dt = torch.split(in_projected_states, split_projection_dim, dim=-1) + d_mlp = ( + projected_states.shape[-1] + - 2 * self.intermediate_size + - 2 * self.n_groups * self.ssm_state_size + - self.num_heads + ) // 2 + + # Single step calculations via cache + if cache_params is not None and cache_position is not None and cache_position[0] > 0: + _, _, gate, hidden_states_B_C, dt = projected_states.squeeze(1).split( + [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 + ) + # 2. Convolution sequence transformation hidden_states_B_C = causal_conv1d_update( hidden_states_B_C, cache_params.conv_states[self.layer_idx], @@ -295,8 +340,9 @@ def cuda_kernels_forward( [self.intermediate_size, groups_time_state_size, groups_time_state_size], dim=-1, ) - A = -torch.exp(self.A_log.float()) # (nheads,) + # 3. SSM transformation + A = -torch.exp(self.A_log.float()) # (nheads,) A = A[:, None, ...][:, :, None].expand(-1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32) dt = dt[:, :, None].expand(-1, -1, self.head_dim) dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim) @@ -318,20 +364,18 @@ def cuda_kernels_forward( ) hidden_states = hidden_states.view(batch_size, self.num_heads * self.head_dim) hidden_states = self.norm(hidden_states, gate) + + # 4. Final linear projection out = self.out_proj(hidden_states)[:, None, ...] - # if no cache is found, calling the kernel + + # Fused calculations or step by step if no initialized cache is found else: - if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: - # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 - dtype = hidden_states.dtype - hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) - # 1. Gated MLP's linear projection - projected_states = self.in_proj(hidden_states) A = -torch.exp(self.A_log.float()) # (num_heads) or (intermediate_size, state_size) dt_limit_kwargs = {} if self.time_step_limit == (0.0, float("inf")) else {"dt_limit": self.time_step_limit} + # 2-4. 
Fused kernel for conv1d, SSM, and the final projection if self.training and cache_params is None: - out, ssm_state = mamba_split_conv1d_scan_combined( + out = mamba_split_conv1d_scan_combined( projected_states, self.conv1d.weight.squeeze(1), self.conv1d.bias, @@ -348,41 +392,50 @@ def cuda_kernels_forward( headdim=self.head_dim, ngroups=self.n_groups, norm_before_gate=False, - return_final_states=True, + return_final_states=False, **dt_limit_kwargs, ) else: - gate, hidden_states_B_C, time_step = torch.split( - projected_states, - [self.intermediate_size, self.conv_dim, self.num_heads], - dim=-1, + _, _, gate, hidden_states_B_C, dt = projected_states.split( + [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 ) - # 1D Convolution - if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]: + # 2. Convolution sequence transformation + # Init cache + if cache_params is not None: + hidden_states_B_C_transposed = hidden_states_B_C.transpose(1, 2) + conv_states = nn.functional.pad( + hidden_states_B_C_transposed, + (cache_params.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0), + ) + cache_params.update_conv_state( + layer_idx=self.layer_idx, new_conv_state=conv_states, cache_init=True + ) + + if self.activation not in ["silu", "swish"]: hidden_states_B_C = self.act( - self.conv1d(hidden_states_B_C.transpose(1, 2)).transpose(1, 2)[:, :seq_len] - ) # (B, L, self.d_inner + 2 * ngroups * d_state) + self.conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2) + ) else: hidden_states_B_C = causal_conv1d_fn( x=hidden_states_B_C.transpose(1, 2), weight=self.conv1d.weight.squeeze(1), bias=self.conv1d.bias, activation=self.activation, - ).transpose(1, 2)[:, :seq_len] + ).transpose(1, 2) + + hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask) hidden_states, B, C = torch.split( hidden_states_B_C, [self.intermediate_size, groups_time_state_size, groups_time_state_size], dim=-1, ) - if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: - # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 - dtype = hidden_states.dtype - hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) + + # 3. SSM transformation scan_output, ssm_state = mamba_chunk_scan_combined( hidden_states.view(batch_size, seq_len, -1, self.head_dim), - time_step, + dt, A, B.view(batch_size, seq_len, self.n_groups, -1), C.view(batch_size, seq_len, self.n_groups, -1), @@ -395,11 +448,16 @@ def cuda_kernels_forward( dt_softplus=True, **dt_limit_kwargs, ) + + # Init cache if ssm_state is not None and cache_params is not None: - cache_params.ssm_states[self.layer_idx].copy_(ssm_state) + cache_params.update_ssm_state(layer_idx=self.layer_idx, new_ssm_state=ssm_state) + scan_output = scan_output.view(batch_size, seq_len, -1) # Multiply "gate" branch and apply extra normalization layer scan_output = self.norm(scan_output, gate) + + # 4. 
Final linear projection out = self.out_proj(scan_output) return out @@ -407,60 +465,64 @@ def cuda_kernels_forward( def torch_forward(self, input_states, cache_params: Optional[Mamba2Cache]=None, cache_position:Optional[torch.LongTensor]=None, attention_mask: Optional[torch.Tensor]=None): batch_size, seq_len, _ = input_states.shape dtype = input_states.dtype - # Gated MLP's linear projection - projected_states = self.in_proj(input_states.squeeze(1)) - d_mlp = (projected_states.shape[-1] - 2 * self.intermediate_size - 2 * self.n_groups * self.ssm_state_size- self.num_heads) // 2 - _, _, gate, hidden_states, dt = projected_states.split( + + # 1. Gated MLP's linear projection + input_states = apply_mask_to_padding_states(input_states, attention_mask) + projected_states = self.in_proj(input_states) + d_mlp = (projected_states.shape[-1] - 2 * self.intermediate_size - 2 * self.n_groups * self.ssm_state_size-self.num_heads) // 2 + _, _, gate, hidden_states_B_C, dt = projected_states.split( [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 ) - # Convolution sequence transformation - if cache_params is not None: - ssm_state = cache_params.ssm_states[self.layer_idx].clone() - ssm_state = ssm_state.to(hidden_states.device) - if cache_params.seqlen_offset > 0: - conv_state = cache_params.conv_states[self.layer_idx] # [batch, intermediate_size, conv_kernel_size] - conv_state = torch.roll(conv_state, shifts=-1, dims=-1) - # handle batched generation - states are copied through - conv_state[:, :, -1] = hidden_states[:, 0, :] if hidden_states.ndim == 3 else hidden_states - cache_params.conv_states[self.layer_idx].copy_(conv_state) - hidden_states = torch.sum(conv_state.to(projected_states.device) * self.conv1d.weight[:, 0, :], dim=-1) - if self.use_conv_bias: - hidden_states += self.conv1d.bias - hidden_states = self.act(hidden_states).to(dtype)[:, None, ...] # [batch, 1, intermediate_size] : decoding - else: - hidden_states = hidden_states.transpose(1,2) - conv_state = nn.functional.pad( - hidden_states, - (self.conv_kernel_size - hidden_states.shape[-1], 0) - ) - cache_params.conv_states[self.layer_idx].copy_(conv_state) - hidden_states = self.act(self.conv1d(hidden_states).transpose(1,2))[:, :seq_len, :] # [batch, intermediate_size, seq_len] - if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: - dtype = hidden_states.dtype - # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 - hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) - else: - ssm_state = torch.zeros( - (batch_size, self.num_heads, self.head_dim, self.ssm_state_size), - device=hidden_states.device, dtype=dtype + # 2. 
Convolution sequence transformation + if cache_params is not None and cache_position is not None and cache_position[0] > 0: + cache_params.update_conv_state(layer_idx=self.layer_idx, new_conv_state=hidden_states_B_C, cache_init=False) + + # We need to guarantee that anything regarding the cache is on the same device + conv_states = cache_params.conv_states[self.layer_idx].to(device=self.conv1d.weight.device) + + hidden_states_B_C = torch.sum( + conv_states * self.conv1d.weight.squeeze(1), dim=-1 ) - hidden_states = self.act(self.conv1d(hidden_states.transpose(1, 2))[..., :seq_len].transpose(1, 2)) - hidden_states, B, C = torch.split(hidden_states, [self.intermediate_size, self.n_groups * self.ssm_state_size, self.n_groups * self.ssm_state_size], dim=-1) + if self.use_conv_bias: + hidden_states_B_C = hidden_states_B_C + self.conv1d.bias + hidden_states_B_C = self.act(hidden_states_B_C) + else: + # Init cache + if cache_params is not None: + hidden_states_B_C_transposed = hidden_states_B_C.transpose(1, 2) + conv_states = nn.functional.pad( + hidden_states_B_C_transposed, (cache_params.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0) + ) + cache_params.update_conv_state(layer_idx=self.layer_idx, new_conv_state=conv_states, cache_init=True) + + hidden_states_B_C = self.act(self.conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2)) + + hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask) + hidden_states, B, C = torch.split( + hidden_states_B_C, + [self.intermediate_size, self.n_groups * self.ssm_state_size, self.n_groups * self.ssm_state_size], + dim=-1 + ) + + # 3. SSM transformation A = -torch.exp(self.A_log.float()) # [num_heads] - if cache_params is not None and cache_params.seqlen_offset > 0: + if cache_params is not None and cache_position is not None and cache_position[0] > 0: + # We need to guarantee that anything regarding the cache is on the same device + cache_device = cache_params.ssm_states.device + # Note: there is no need to pad parameter matrices here, as there is just one new token # for batched generation - dt = dt[:, None, ...] if dt.ndim == 2 else dt[:, 0, :][:, None, ...] + dt = dt[:, 0, :][:, None, ...] 
dt = dt.transpose(1, 2).expand(batch_size, dt.shape[-1], self.head_dim) # [num_heads] -> [num_heads, head_dim] dt_bias = self.dt_bias[..., None].expand(self.dt_bias.shape[0], self.head_dim) dt = torch.nn.functional.softplus(dt + dt_bias.to(dt.dtype)) - dt = torch.clamp(dt, self.time_step_min) #, self.time_step_max) + dt = torch.clamp(dt, self.time_step_limit[0], self.time_step_limit[1]) A = A[..., None, None].expand(self.num_heads, self.head_dim, self.ssm_state_size).to(dtype=torch.float32) # [bsz, num_heads, head_dim, state_size] - dA = torch.exp(dt[..., None] * A) + dA = (torch.exp(dt[..., None] * A)).to(device=cache_device) # Discretize B # [bsz, n_groups * state_size] -> [bsz, n_groups, 1, state_size] -> @@ -474,11 +536,12 @@ def torch_forward(self, input_states, cache_params: Optional[Mamba2Cache]=None, # Discretize x into dB # [bsz, intermediate_size] -> [bsz, num_heads, head_dim] hidden_states = hidden_states.reshape(batch_size, -1, self.head_dim) - dBx = dB * hidden_states[..., None] + dBx = (dB * hidden_states[..., None]).to(device=cache_device) # State calculation - cache_params.ssm_states[self.layer_idx].copy_( - cache_params.ssm_states[self.layer_idx] * dA + dBx + cache_params.update_ssm_state( + layer_idx=self.layer_idx, + new_ssm_state=cache_params.ssm_states[self.layer_idx] * dA + dBx ) # Subsequent output @@ -488,7 +551,7 @@ def torch_forward(self, input_states, cache_params: Optional[Mamba2Cache]=None, C = C.reshape(batch_size, -1, C.shape[-1]) # [bsz, num_heads, head_dim] - ssm_states = cache_params.ssm_states[self.layer_idx].to(C.dtype) # Shape: [b, h, d, n] + ssm_states = cache_params.ssm_states[self.layer_idx].to(device=C.device, dtype=C.dtype) # Shape: [b, h, d, n] # Reshape ssm_states to merge the first two dimensions ssm_states_reshaped = ssm_states.view(batch_size * self.num_heads, self.head_dim, self.ssm_state_size) # Shape: [b*h, d, n] C_reshaped = C.view(batch_size * self.num_heads, self.ssm_state_size, 1) # Shape: [b*h, n, 1] @@ -505,9 +568,9 @@ def torch_forward(self, input_states, cache_params: Optional[Mamba2Cache]=None, else: # begin ssd naive implementation without einsums dt = nn.functional.softplus(dt + self.dt_bias) - dt = torch.clamp(dt, self.time_step_min) + dt = torch.clamp(dt, self.time_step_limit[0], self.time_step_limit[1]) hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).float() - B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).float() + B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).float() C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).float() B = B.repeat(1, 1, self.num_heads // self.n_groups, 1) C = C.repeat(1, 1, self.num_heads // self.n_groups, 1) @@ -522,7 +585,6 @@ def torch_forward(self, input_states, cache_params: Optional[Mamba2Cache]=None, # Rearrange into blocks/chunks hidden_states, A, B, C = [reshape_into_chunks(t, pad_size, self.chunk_size) for t in (hidden_states, A, B, C)] - # [bsz, -1, chunk_size, num_heads] -> [bsz, num_heads, -1, chunk_size] A = A.permute(0, 3, 1, 2) A_cumsum = torch.cumsum(A, dim=-1) @@ -531,45 +593,43 @@ def torch_forward(self, input_states, cache_params: Optional[Mamba2Cache]=None, # This is the analog of a causal mask L = torch.exp(segment_sum(A)) - # First, contraction of C and B to get G (attention-weights like) - G_intermediate = C[:, :, :, None, :, :] * B[:, :, None, :, : ,:] # shape: (b, c, l, s, h, n) + # Contraction of C and B to get G (attention-weights like) + G_intermediate = C[:, :, :, None, :, :] * B[:, :, None, :, :, :] # 
shape: (b, c, l, s, h, n) G = G_intermediate.sum(dim=-1) # shape: (b, c, l, s, h) - - # Step 2: Compute M, equivalent to applying attention mask to weights + # Compute M, equivalent to applying attention mask to weights M_intermediate = G[..., None] * L.permute(0, 2, 3, 4, 1)[..., None] M = M_intermediate.sum(dim=-1) - # Step 3: Compute Y_diag (apply to values) - Y_diag = (M[..., None] * hidden_states[:, :, None]).sum(3) + # Compute Y_diag (apply to values) + Y_diag = (M[..., None] * hidden_states[:, :, None]).sum(dim=3) + # 2. Compute the state for each intra-chunk # (right term of low-rank factorization of off-diagonal blocks; B terms) - decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum)) - B_decay_contraction = B * decay_states.permute(0, 2, 3, 1)[..., None] - # permute back B * decay states - states = (B_decay_contraction.permute(0, 1, 3, 2, 4)[..., None] * hidden_states.permute(0, 1, 3, 2, 4)[..., None, :]).sum(dim=3).permute(0, 1, 2, 4, 3) - if cache_params is not None and cache_params.seqlen_offset > 0: - previous_states = cache_params.ssm_states[self.layer_idx][:, None, ...] + B_decay = B * decay_states.permute(0, -2, -1, 1)[..., None] + states = (B_decay[..., None, :] * hidden_states[..., None]).sum(dim=2) + + # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries + # (middle term of factorization of off-diag blocks; A terms) + if cache_params is not None and cache_position is not None and cache_position[0] > 0: + previous_states = cache_params.ssm_states[self.layer_idx][:, None, ...].to(device=states.device) else: previous_states = torch.zeros_like(states[:, :1]) states = torch.cat([previous_states, states], dim=1) decay_chunk = torch.exp(segment_sum(nn.functional.pad(A_cumsum[:, :, :, -1], (1, 0)))) - - states_permuted = states.permute(0, 2, 1, 3, 4) - result = (decay_chunk[..., None, None] * states_permuted[:, :, None, ...]).sum(dim=2) - new_states = result.permute(0, 2, 1, 3, 4) + decay_chunk = decay_chunk.transpose(1, 3) + new_states = (decay_chunk[..., None, None] * states[:, :, None, ...]).sum(dim=1) states, ssm_state = new_states[:, :-1], new_states[:, -1] - # Compute state -> output conversion per chunk + # 4. 
Compute state -> output conversion per chunk # (left term of low-rank factorization of off-diagonal blocks; C terms) state_decay_out = torch.exp(A_cumsum) - # compute Yoff C_times_states = (C[..., None, :] * states[:, :, None, ...]) state_decay_out_permuted = state_decay_out.permute(0, 2, 3, 1) Y_off = (C_times_states.sum(-1) * state_decay_out_permuted[..., None]) - # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks) + # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks) y = Y_diag + Y_off # [bsz, -1, self.chunk_size, num_heads, head_dim] -> [bsz, (padded) seq_len, num_heads, head_dim] y = y.reshape(batch_size, -1, self.num_heads, self.head_dim) @@ -579,8 +639,10 @@ def torch_forward(self, input_states, cache_params: Optional[Mamba2Cache]=None, if pad_size > 0: y = y[:, :seq_len, :, :] y = y.reshape(batch_size, seq_len, -1) + + # Init cache if ssm_state is not None and cache_params is not None: - cache_params.ssm_states[self.layer_idx].copy_(ssm_state) + cache_params.update_ssm_state(layer_idx=self.layer_idx, new_ssm_state=ssm_state) scan_output = self.norm(y, gate) @@ -916,9 +978,6 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if use_cache: - cache_params.seqlen_offset += inputs_embeds.shape[1] - hidden_states = self.norm_f(hidden_states) if output_hidden_states: @@ -975,10 +1034,6 @@ def prepare_inputs_for_generation( ): # Overwitten -- uses `cache_params` as opposed to `past_key_values` - if inputs_embeds is not None: - past_len = inputs_embeds.shape[1] + input_ids.shape[1] - else: - past_len = input_ids.shape[1] if use_cache: # `cache_position` should have been initialized in `generate` if cache_position is None: @@ -987,33 +1042,18 @@ def prepare_inputs_for_generation( "`model.generate`, you are responsible for passing in a valid `cache_position` if " "you are calling `prepare_inputs_for_generation` directly with `use_cache=True`" ) - # how do we detect that we are in decoding without cache? 
if cache_position[0] > 0: input_ids = input_ids[:, -1][..., None] - attention_mask = attention_mask[:, -1][..., None] + + if attention_mask is not None: + attention_mask = None else: # we initialize the `cache_position` to full size of `conv_states` at prefill stage # considering padding will be applied when input length is shorter, and truncation # will be applied when it is longer, so it will be equivalent to always have it match # the length of `cache_params.conv_states`, which is `config.conv_kernel` - cache_position = torch.arange(0, past_len, device=input_ids.device) - # if the cache is not used, we also do have to extend the attention mask here - # TODO there is likely a cleverer way to do this - extended_mask = torch.ones( - attention_mask.size(0), past_len - attention_mask.shape[1], device=attention_mask.device - ) - attention_mask = torch.cat([attention_mask, extended_mask], dim=1) - cache_params = None - - if attention_mask.shape[1] < past_len: - # we have to update manually the attention mask if - # we are in decoding without cache - # and we don't have position_ids here - # TODO but we should be able to use cache_position though at a later time - extended_mask = torch.ones( - attention_mask.size(0), past_len - attention_mask.shape[1], device=attention_mask.device - ) - attention_mask = torch.cat([attention_mask, extended_mask], dim=1) + cache_position = torch.arange(0, self.config.conv_kernel, device=input_ids.device) + if inputs_embeds is not None and cache_params is None: model_inputs = {"inputs_embeds": inputs_embeds} else: diff --git a/tests/models/mamba2/test_modeling_mamba2.py b/tests/models/mamba2/test_modeling_mamba2.py index 9b3a9563b58ddc..c2ef68f2614ea5 100644 --- a/tests/models/mamba2/test_modeling_mamba2.py +++ b/tests/models/mamba2/test_modeling_mamba2.py @@ -21,6 +21,7 @@ from transformers import AutoTokenizer, Mamba2Config, is_torch_available from transformers.testing_utils import require_read_token, require_torch, require_torch_gpu, slow, torch_device +from transformers.utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available from ...generation.test_utils import GenerationTesterMixin from ...test_configuration_common import ConfigTester @@ -103,6 +104,10 @@ def prepare_config_and_inputs( ): input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + # Only left padding is valid + attention_mask = torch.ones(size=(self.batch_size, self.seq_length), device=input_ids.device, dtype=torch.long) + attention_mask[0, :1] = 0 + sequence_labels = None token_labels = None choice_labels = None @@ -118,7 +123,7 @@ def prepare_config_and_inputs( return ( config, input_ids, - None, + attention_mask, sequence_labels, token_labels, choice_labels, @@ -158,6 +163,56 @@ def prepare_config_and_inputs_for_common(self): inputs_dict = {"input_ids": input_ids} return config, inputs_dict + def create_and_check_mamba2_caching(self, config, input_ids, attention_mask, *args): + model = Mamba2Model(config=config) + model.to(torch_device) + model.eval() + + output_whole = model(input_ids, attention_mask=attention_mask).last_hidden_state + + outputs = model( + input_ids[:, :-1], + attention_mask=attention_mask[:, :-1], + use_cache=True, + cache_position=torch.arange(0, config.conv_kernel, device=input_ids.device), + ) + output_one = outputs.last_hidden_state + + # Using the state computed on the first inputs, we will get the same output + outputs = model( + input_ids[:, -1:], + attention_mask=attention_mask[:, -1:], + use_cache=True, + 
cache_params=outputs.cache_params, + cache_position=torch.arange(config.conv_kernel, config.conv_kernel + 1, device=input_ids.device), + ) + output_two = outputs.last_hidden_state + + self.parent.assertTrue( + torch.allclose(torch.cat([output_one, output_two], dim=1), output_whole, atol=1e-3, rtol=1e-3) + ) + + def create_and_check_mamba2_slow_vs_fast_forward(self, config, input_ids, *args, gradient_checkpointing=False): + model = Mamba2Model(config) + model.eval() + + if not (is_mamba_2_ssm_available() and is_causal_conv1d_available()): + self.parent.skipTest( + "This test needs the Mamba2 fast path. Skipping as the necessary packages have not been found." + ) + if torch_device != "cuda": + self.parent.skipTest("This test needs the Mamba2 fast path. Skipping as we need a cuda capable device.") + + model.to(torch_device) + if gradient_checkpointing: + model.gradient_checkpointing_enable() + + token_emb = model.embeddings(input_ids) + outputs_fast = model.layers[0].mixer.cuda_kernels_forward(token_emb) + outputs_slow = model.layers[0].mixer.torch_forward(token_emb) + + self.parent.assertTrue(torch.allclose(outputs_fast, outputs_slow, atol=1e-3, rtol=1e-3)) + @unittest.skipIf( not is_torch_greater_or_equal_than_2_0, reason="See https://github.com/huggingface/transformers/pull/24204" @@ -184,6 +239,14 @@ def setUp(self): self, config_class=Mamba2Config, n_embd=37, common_properties=["hidden_size", "num_hidden_layers"] ) + def test_mamba2_caching(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_mamba2_caching(*config_and_inputs) + + def test_mamba2_slow_vs_fast_forward(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_mamba2_slow_vs_fast_forward(*config_and_inputs) + def test_initialization(self): config, _ = self.model_tester.prepare_config_and_inputs_for_common() @@ -199,23 +262,6 @@ def test_initialization(self): def test_tied_weights_keys(self): pass - @unittest.skip(reason="To fix, Mamba 2 cache slicing test case is an edge case") - def test_generate_without_input_ids(self): - pass - - @unittest.skip(reason="To fix, Mamba 2 cache slicing test case is an edge case") - @parameterized.expand([("greedy", 1), ("beam search", 2)]) - def test_generate_from_inputs_embeds(self, _, num_beams): - pass - - @unittest.skip(reason="To fix, Mamba 2 cache slicing test case is an edge case") - def test_greedy_generate_dict_outputs_use_cache(self): - pass - - @unittest.skip(reason="To fix, Mamba 2 cache slicing is interacting with beam search") - def test_beam_search_generate_dict_outputs_use_cache(self): - pass - @unittest.skip(reason="A large mamba2 would be necessary (and costly) for that") def test_multi_gpu_data_parallel_forward(self): pass From 4e27a4009d3f9d4e44e9be742e8cd742daf074f4 Mon Sep 17 00:00:00 2001 From: wejoncy <247153481@qq.com> Date: Fri, 20 Dec 2024 16:45:53 +0800 Subject: [PATCH 14/23] FEAT : Adding VPTQ quantization method to HFQuantizer (#34770) * init vptq * add integration * add vptq support fix readme * add tests && format * format * address comments * format * format * address comments * format * address comments * remove debug code * Revert "remove debug code" This reverts commit ed3b3eaaba82caf58cb3aa6e865d98e49650cf66. 
* fix test --------- Co-authored-by: Yang Wang --- .../Dockerfile | 3 + docs/source/ar/_toctree.yml | 2 + docs/source/en/_toctree.yml | 2 + docs/source/en/llm_optims.md | 2 +- docs/source/en/main_classes/quantization.md | 4 + docs/source/en/quantization/overview.md | 3 +- docs/source/en/quantization/vptq.md | 111 ++++++++++ docs/source/ko/_toctree.yml | 4 + docs/source/ko/llm_optims.md | 2 +- docs/source/ko/main_classes/quantization.md | 4 + src/transformers/__init__.py | 2 + src/transformers/integrations/__init__.py | 2 + src/transformers/integrations/vptq.py | 101 +++++++++ src/transformers/quantizers/auto.py | 4 + src/transformers/quantizers/quantizer_vptq.py | 98 +++++++++ src/transformers/testing_utils.py | 8 + src/transformers/utils/__init__.py | 1 + src/transformers/utils/import_utils.py | 6 + src/transformers/utils/quantization_config.py | 97 +++++++++ .../quantization/vptq_integration/__init__.py | 0 .../vptq_integration/test_vptq.py | 194 ++++++++++++++++++ 21 files changed, 647 insertions(+), 3 deletions(-) create mode 100644 docs/source/en/quantization/vptq.md create mode 100644 src/transformers/integrations/vptq.py create mode 100644 src/transformers/quantizers/quantizer_vptq.py create mode 100644 tests/quantization/vptq_integration/__init__.py create mode 100644 tests/quantization/vptq_integration/test_vptq.py diff --git a/docker/transformers-quantization-latest-gpu/Dockerfile b/docker/transformers-quantization-latest-gpu/Dockerfile index 089be4a4460101..3cb2acdc53bb1a 100755 --- a/docker/transformers-quantization-latest-gpu/Dockerfile +++ b/docker/transformers-quantization-latest-gpu/Dockerfile @@ -50,6 +50,9 @@ RUN python3 -m pip install --no-cache-dir git+https://github.com/huggingface/pef # Add aqlm for quantization testing RUN python3 -m pip install --no-cache-dir aqlm[gpu]==1.0.2 +# Add vptq for quantization testing +RUN python3 -m pip install --no-cache-dir vptq + # Add hqq for quantization testing RUN python3 -m pip install --no-cache-dir hqq diff --git a/docs/source/ar/_toctree.yml b/docs/source/ar/_toctree.yml index 138d3a1bd8aa08..287f4dffbb384e 100644 --- a/docs/source/ar/_toctree.yml +++ b/docs/source/ar/_toctree.yml @@ -157,6 +157,8 @@ # title: AWQ # - local: quantization/aqlm # title: AQLM +# - local: quantization/vptq +# title: VPTQ # - local: quantization/quanto # title: Quanto # - local: quantization/eetq diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 8138dd41d80c12..18de03e1df8016 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -167,6 +167,8 @@ title: AWQ - local: quantization/aqlm title: AQLM + - local: quantization/vptq + title: VPTQ - local: quantization/quanto title: Quanto - local: quantization/eetq diff --git a/docs/source/en/llm_optims.md b/docs/source/en/llm_optims.md index e97ace8a625050..17ebb841de7a39 100644 --- a/docs/source/en/llm_optims.md +++ b/docs/source/en/llm_optims.md @@ -473,7 +473,7 @@ with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable Quantization reduces the size of the LLM weights by storing them in a lower precision. This translates to lower memory usage and makes loading LLMs for inference more accessible if you're constrained by your GPUs memory. If you aren't limited by your GPU, you don't necessarily need to quantize your model because it can incur a small latency cost (except for AWQ and fused AWQ modules) due to the extra step required to quantize and dequantize the weights. 
> [!TIP] -> There are many quantization libraries (see the [Quantization](./quantization) guide for more details) available, such as Quanto, AQLM, AWQ, and AutoGPTQ. Feel free to try them out and see which one works best for your use case. We also recommend reading the [Overview of natively supported quantization schemes in 🤗 Transformers](https://hf.co/blog/overview-quantization-transformers) blog post which compares AutoGPTQ and bitsandbytes. +> There are many quantization libraries (see the [Quantization](./quantization) guide for more details) available, such as Quanto, AQLM, VPTQ, AWQ, and AutoGPTQ. Feel free to try them out and see which one works best for your use case. We also recommend reading the [Overview of natively supported quantization schemes in 🤗 Transformers](https://hf.co/blog/overview-quantization-transformers) blog post which compares AutoGPTQ and bitsandbytes. Use the Model Memory Calculator below to estimate and compare how much memory is required to load a model. For example, try estimating how much memory it costs to load [Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1). diff --git a/docs/source/en/main_classes/quantization.md b/docs/source/en/main_classes/quantization.md index 3f44569697777b..9b500b69374c88 100755 --- a/docs/source/en/main_classes/quantization.md +++ b/docs/source/en/main_classes/quantization.md @@ -34,6 +34,10 @@ Learn how to quantize models in the [Quantization](../quantization) guide. [[autodoc]] AqlmConfig +## VptqConfig + +[[autodoc]] VptqConfig + ## AwqConfig [[autodoc]] AwqConfig diff --git a/docs/source/en/quantization/overview.md b/docs/source/en/quantization/overview.md index 0fb72d26058e55..f3508aed0674f6 100644 --- a/docs/source/en/quantization/overview.md +++ b/docs/source/en/quantization/overview.md @@ -58,6 +58,7 @@ Use the table below to help you decide which quantization method to use. | [optimum-quanto](./quanto) | 🟢 | 🟢 | 🟢 | 🔴 | 🟢 | 🔴 | 🟢 | 2 / 4 / 8 | 🔴 | 🔴 | 🟢 | https://github.com/huggingface/optimum-quanto | | [FBGEMM_FP8](./fbgemm_fp8.md) | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 🔴 | 8 | 🔴 | 🟢 | 🟢 | https://github.com/pytorch/FBGEMM | | [torchao](./torchao.md) | 🟢 | | 🟢 | 🔴 | partial support (int4 weight only) | 🔴 | | 4 / 8 | | 🟢🔴 | 🟢 | https://github.com/pytorch/ao | +| [VPTQ](./vptq) | 🔴 | 🔴 | 🟢 | 🟡 | 🔴 | 🔴 | 🟢 | 1 - 8 | 🔴 | 🟢 | 🟢 | https://github.com/microsoft/VPTQ | @@ -71,4 +72,4 @@ We value your feedback to help identify bugs before the full release! Check out \** bitsandbytes is seeking contributors to help develop and lead the Apple Silicon backend. Interested? Contact them directly via their repo. Stipends may be available through sponsorships. - + \ No newline at end of file diff --git a/docs/source/en/quantization/vptq.md b/docs/source/en/quantization/vptq.md new file mode 100644 index 00000000000000..b86e82f0a3503d --- /dev/null +++ b/docs/source/en/quantization/vptq.md @@ -0,0 +1,111 @@ + + +# VPTQ + +> [!TIP] +> Try VPTQ on [Hugging Face](https://huggingface.co/spaces/microsoft/VPTQ)! +> Try VPTQ on [Google Colab](https://colab.research.google.com/github/microsoft/VPTQ/blob/main/notebooks/vptq_example.ipynb)! +> Know more about VPTQ on [ArXiv](https://arxiv.org/pdf/2409.17066)! + +Vector Post-Training Quantization ([VPTQ](https://github.com/microsoft/VPTQ)) is a novel Post-Training Quantization method that leverages Vector Quantization to high accuracy on LLMs at an extremely low bit-width (<2-bit). VPTQ can compress 70B, even the 405B model, to 1-2 bits without retraining and maintain high accuracy. 
+
+- Better Accuracy on 1-2 bits (405B @ <2bit, 70B @ 2bit)
+- Lightweight Quantization Algorithm: only costs ~17 hours to quantize 405B Llama-3.1
+- Agile Quantization Inference: low decode overhead, best throughput, and TTFT
+
+Inference support for VPTQ is released in the `vptq` library. Make sure to install it to run the models:
+```bash
+pip install vptq
+```
+
+The library provides efficient kernels for NVIDIA/AMD GPU inference.
+
+To run VPTQ models, simply load a model that has been quantized with VPTQ:
+
+## Inference example
+**Run Llama 3.1 70b on RTX4090 (24G @ ~2bits) in real time**
+![Llama3 1-70b-prompt](https://github.com/user-attachments/assets/d8729aca-4e1d-4fe1-ac71-c14da4bdd97f)
+
+
+```python
+from transformers import AutoTokenizer, AutoModelForCausalLM
+
+quantized_model = AutoModelForCausalLM.from_pretrained(
+    "VPTQ-community/Meta-Llama-3.1-70B-Instruct-v16-k65536-65536-woft",
+    torch_dtype="auto",
+    device_map="auto"
+)
+tokenizer = AutoTokenizer.from_pretrained("VPTQ-community/Meta-Llama-3.1-70B-Instruct-v16-k65536-65536-woft")
+input_ids = tokenizer("hello, it's me", return_tensors="pt").to("cuda")
+out = quantized_model.generate(**input_ids, max_new_tokens=32, do_sample=False)
+```
+
+## Quantize your own model
+The VPTQ algorithm is early-released at [VPTQ](https://github.com/microsoft/VPTQ/tree/algorithm);
+check out the [tutorial](https://github.com/microsoft/VPTQ/blob/algorithm/algorithm.md) for details.
+
+## Early Results from Tech Report
+VPTQ achieves better accuracy and higher throughput with lower quantization overhead across models of different sizes. The following experimental results are for reference only; VPTQ can achieve better outcomes under reasonable parameters, especially in terms of model accuracy and inference speed.
+
+
+| Model | bitwidth | W2↓ | C4↓ | AvgQA↑ | tok/s↑ | mem(GB) | cost/h↓ |
+| ----------- | -------- | ---- | ---- | ------ | ------ | ------- | ------- |
+| LLaMA-2 7B | 2.02 | 6.13 | 8.07 | 58.2 | 39.9 | 2.28 | 2 |
+| | 2.26 | 5.95 | 7.87 | 59.4 | 35.7 | 2.48 | 3.1 |
+| LLaMA-2 13B | 2.02 | 5.32 | 7.15 | 62.4 | 26.9 | 4.03 | 3.2 |
+| | 2.18 | 5.28 | 7.04 | 63.1 | 18.5 | 4.31 | 3.6 |
+| LLaMA-2 70B | 2.07 | 3.93 | 5.72 | 68.6 | 9.7 | 19.54 | 19 |
+| | 2.11 | 3.92 | 5.71 | 68.7 | 9.7 | 20.01 | 19 |
+
+
+
+## More Models in [VPTQ-community](https://huggingface.co/VPTQ-community)
+
+⚠️ The repository only provides the model quantization algorithm.
+
+⚠️ The open-source community VPTQ-community provides models based on the technical report and quantization algorithm.
+
+
+
+**Quick Estimation of Model Bitwidth (Excluding Codebook Overhead)**:
+
+- **Model Naming Convention**: The model's name includes the **vector length** $v$, **codebook (lookup table) size**, and **residual codebook size**. For example, "Meta-Llama-3.1-70B-Instruct-v8-k65536-256-woft" is "Meta-Llama-3.1-70B-Instruct", where:
+  - **Vector Length**: 8
+  - **Number of Centroids**: 65536 (2^16)
+  - **Number of Residual Centroids**: 256 (2^8)
+- **Equivalent Bitwidth Calculation**:
+  - **Index**: log2(65536) / 8 = 16 / 8 = 2 bits
+  - **Residual Index**: log2(256) / 8 = 8 / 8 = 1 bit
+  - **Total Bitwidth**: 2 + 1 = 3 bits
+- **Model Size Estimation**: 70B * 3 bits / 8 bits per Byte = 26.25 GB
+
+- **Note**: This estimate does not include the size of the codebook (lookup table), other parameter overheads, and the padding overhead for storing indices. For the detailed calculation method, please refer to **Tech Report Appendix C.2**. A scripted version of this estimate is sketched below.
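+
+The same arithmetic can be scripted for a quick sanity check. The helper below is only a back-of-the-envelope sketch (it is not part of the `vptq` package, and it ignores codebook and padding overheads):
+
+```python
+import math
+
+
+def estimated_bits_per_weight(vector_len: int, num_centroids: int, num_res_centroids: int) -> float:
+    """Equivalent index bitwidth per weight, excluding codebook (lookup table) overhead."""
+    index_bits = math.log2(num_centroids) / vector_len
+    # models quantized without a residual codebook (the "-0-" variants) use no residual centroids
+    residual_bits = math.log2(num_res_centroids) / vector_len if num_res_centroids > 1 else 0.0
+    return index_bits + residual_bits
+
+
+# "Meta-Llama-3.1-70B-Instruct-v8-k65536-256-woft": vector length 8, 65536 centroids, 256 residual centroids
+bits = estimated_bits_per_weight(8, 65536, 256)  # 16/8 + 8/8 = 3.0
+print(f"~{bits:.2f} bits/weight, ~{70e9 * bits / 8 / 1e9:.2f} GB for a 70B model")  # ~3.00 bits/weight, ~26.25 GB
+```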
+ + +| Model Series | Collections | (Estimated) Bit per weight | +| :--------------------------------: | :-----------------------------------------------------------------------------------------------------------------------------: | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| Llama 3.1 Nemotron 70B Instruct HF | [HF 🤗](https://huggingface.co/collections/VPTQ-community/vptq-llama-31-nemotron-70b-instruct-hf-without-finetune-671730b96f16208d0b3fe942) | [4 bits](https://huggingface.co/VPTQ-community/Llama-3.1-Nemotron-70B-Instruct-HF-v8-k65536-65536-woft) [3 bits](https://huggingface.co/VPTQ-community/Llama-3.1-Nemotron-70B-Instruct-HF-v8-k65536-256-woft) [2 bits (1)](https://huggingface.co/VPTQ-community/Llama-3.1-Nemotron-70B-Instruct-HF-v16-k65536-65536-woft) [2 bits (2)](https://huggingface.co/VPTQ-community/Llama-3.1-Nemotron-70B-Instruct-HF-v8-k65536-0-woft) [1.875 bits](https://huggingface.co/VPTQ-community/Llama-3.1-Nemotron-70B-Instruct-HF-v16-k65536-16384-woft) [1.625 bits](https://huggingface.co/VPTQ-community/Llama-3.1-Nemotron-70B-Instruct-HF-v16-k65536-1024-woft) [1.5 bits](https://huggingface.co/VPTQ-community/Llama-3.1-Nemotron-70B-Instruct-HF-v16-k65536-256-woft) | +| Llama 3.1 8B Instruct | [HF 🤗](https://huggingface.co/collections/VPTQ-community/vptq-llama-31-8b-instruct-without-finetune-66f2b70b1d002ceedef02d2e) | [4 bits](https://huggingface.co/VPTQ-community/Meta-Llama-3.1-8B-Instruct-v8-k65536-65536-woft) [3.5 bits](https://huggingface.co/VPTQ-community/Meta-Llama-3.1-8B-Instruct-v8-k65536-4096-woft) [3 bits](https://huggingface.co/VPTQ-community/Meta-Llama-3.1-8B-Instruct-v8-k65536-256-woft) [2.3 bits](https://huggingface.co/VPTQ-community/Meta-Llama-3.1-8B-Instruct-v12-k65536-4096-woft) | +| Llama 3.1 70B Instruct | [HF 🤗](https://huggingface.co/collections/VPTQ-community/vptq-llama-31-70b-instruct-without-finetune-66f2bf454d3dd78dfee2ff11) | [4 bits](https://huggingface.co/VPTQ-community/Meta-Llama-3.1-70B-Instruct-v8-k65536-65536-woft) [3 bits](https://huggingface.co/VPTQ-community/Meta-Llama-3.1-70B-Instruct-v8-k65536-256-woft) [2.25 bits](https://huggingface.co/VPTQ-community/Meta-Llama-3.1-70B-Instruct-v8-k65536-4-woft) [2 bits (1)](https://huggingface.co/VPTQ-community/Meta-Llama-3.1-70B-Instruct-v16-k65536-65536-woft) [2 bits (2)](https://huggingface.co/VPTQ-community/Meta-Llama-3.1-70B-Instruct-v8-k65536-0-woft) [1.93 bits](https://huggingface.co/VPTQ-community/Meta-Llama-3.1-70B-Instruct-v16-k65536-32768-woft) [1.875 bits](https://huggingface.co/VPTQ-community/Meta-Llama-3.1-70B-Instruct-v8-k32768-0-woft) [1.75 bits](https://huggingface.co/VPTQ-community/Meta-Llama-3.1-70B-Instruct-v8-k16384-0-woft) | +| Llama 3.1 405B Instruct | [HF 
🤗](https://huggingface.co/collections/VPTQ-community/vptq-llama-31-405b-instruct-without-finetune-66f4413f9ba55e1a9e52cfb0) | [4 bits](https://huggingface.co/VPTQ-community/Meta-Llama-3.1-405B-Instruct-v8-k65536-65536-woft) [3 bits](https://huggingface.co/VPTQ-community/Meta-Llama-3.1-405B-Instruct-v8-k65536-256-woft) [2 bits](https://huggingface.co/VPTQ-community/Meta-Llama-3.1-405B-Instruct-v16-k65536-65536-woft) [1.875 bits](https://huggingface.co/VPTQ-community/Meta-Llama-3.1-405B-Instruct-v16-k32768-32768-woft) [1.625 bits](https://huggingface.co/VPTQ-community/Meta-Llama-3.1-405B-Instruct-v16-k65536-1024-woft) [1.5 bits (1)](https://huggingface.co/VPTQ-community/Meta-Llama-3.1-405B-Instruct-v8-k4096-0-woft) [1.5 bits (2)](https://huggingface.co/VPTQ-community/Meta-Llama-3.1-405B-Instruct-v16-k65536-256-woft) [1.43 bits](https://huggingface.co/VPTQ-community/Meta-Llama-3.1-405B-Instruct-v16-k65536-128-woft) [1.375 bits](https://huggingface.co/VPTQ-community/Meta-Llama-3.1-405B-Instruct-v16-k65536-64-woft) | +| Mistral Large Instruct 2407 (123B) | [HF 🤗](https://huggingface.co/collections/VPTQ-community/vptq-mistral-large-instruct-2407-without-finetune-6711ebfb7faf85eed9cceb16) | [4 bits](https://huggingface.co/VPTQ-community/Mistral-Large-Instruct-2407-v8-k65536-65536-woft) [3 bits](https://huggingface.co/VPTQ-community/Mistral-Large-Instruct-2407-v8-k65536-256-woft) [2 bits (1)](https://huggingface.co/VPTQ-community/Mistral-Large-Instruct-2407-v16-k65536-65536-woft) [2 bits (2)](https://huggingface.co/VPTQ-community/Mistral-Large-Instruct-2407-v8-k65536-0-woft) [1.875 bits](https://huggingface.co/VPTQ-community/Mistral-Large-Instruct-2407-v16-k65536-16384-woft) [1.75 bits](https://huggingface.co/VPTQ-community/Mistral-Large-Instruct-2407-v16-k65536-4096-woft) [1.625 bits](https://huggingface.co/VPTQ-community/Mistral-Large-Instruct-2407-v16-k65536-1024-woft) [1.5 bits](https://huggingface.co/VPTQ-community/Mistral-Large-Instruct-2407-v16-k65536-256-woft) | +| Qwen 2.5 7B Instruct | [HF 🤗](https://huggingface.co/collections/VPTQ-community/vptq-qwen-25-7b-instruct-without-finetune-66f3e9866d3167cc05ce954a) | [4 bits](https://huggingface.co/VPTQ-community/Qwen2.5-7B-Instruct-v8-k65536-65536-woft) [3 bits](https://huggingface.co/VPTQ-community/Qwen2.5-7B-Instruct-v8-k65536-256-woft) [2 bits (1)](https://huggingface.co/VPTQ-community/Qwen2.5-7B-Instruct-v8-k256-256-woft) [2 bits (2)](https://huggingface.co/VPTQ-community/Qwen2.5-7B-Instruct-v8-k65536-0-woft) [2 bits (3)](https://huggingface.co/VPTQ-community/Qwen2.5-7B-Instruct-v16-k65536-65536-woft) | +| Qwen 2.5 14B Instruct | [HF 🤗](https://huggingface.co/collections/VPTQ-community/vptq-qwen-25-14b-instruct-without-finetune-66f827f83c7ffa7931b8376c) | [4 bits](https://huggingface.co/VPTQ-community/Qwen2.5-14B-Instruct-v8-k65536-65536-woft) [3 bits](https://huggingface.co/VPTQ-community/Qwen2.5-14B-Instruct-v8-k65536-256-woft) [2 bits (1)](https://huggingface.co/VPTQ-community/Qwen2.5-14B-Instruct-v8-k256-256-woft) [2 bits (2)](https://huggingface.co/VPTQ-community/Qwen2.5-14B-Instruct-v8-k65536-0-woft) [2 bits (3)](https://huggingface.co/VPTQ-community/Qwen2.5-14B-Instruct-v16-k65536-65536-woft) | +| Qwen 2.5 32B Instruct | [HF 🤗](https://huggingface.co/collections/VPTQ-community/vptq-qwen-25-32b-instruct-without-finetune-66fe77173bf7d64139f0f613) | [4 bits](https://huggingface.co/VPTQ-community/Qwen2.5-32B-Instruct-v8-k65536-65536-woft) [3 bits](https://huggingface.co/VPTQ-community/Qwen2.5-32B-Instruct-v8-k65536-256-woft) [2 bits 
(1)](https://huggingface.co/VPTQ-community/Qwen2.5-32B-Instruct-v16-k65536-65536-woft) [2 bits (2)](https://huggingface.co/VPTQ-community/Qwen2.5-32B-Instruct-v8-k65536-0-woft) [2 bits (3)](https://huggingface.co/VPTQ-community/Qwen2.5-32B-Instruct-v8-k256-256-woft) | +| Qwen 2.5 72B Instruct | [HF 🤗](https://huggingface.co/collections/VPTQ-community/vptq-qwen-25-72b-instruct-without-finetune-66f3bf1b3757dfa1ecb481c0) | [4 bits](https://huggingface.co/VPTQ-community/Qwen2.5-72B-Instruct-v8-k65536-65536-woft) [3 bits](https://huggingface.co/VPTQ-community/Qwen2.5-72B-Instruct-v8-k65536-256-woft) [2.38 bits](https://huggingface.co/VPTQ-community/Qwen2.5-72B-Instruct-v8-k1024-512-woft) [2.25 bits (1)](https://huggingface.co/VPTQ-community/Qwen2.5-72B-Instruct-v8-k512-512-woft) [2.25 bits (2)](https://huggingface.co/VPTQ-community/Qwen2.5-72B-Instruct-v8-k65536-4-woft) [2 bits (1)](https://huggingface.co/VPTQ-community/Qwen2.5-72B-Instruct-v8-k65536-0-woft) [2 bits (2)](https://huggingface.co/VPTQ-community/Qwen2.5-72B-Instruct-v16-k65536-65536-woft) [1.94 bits](https://huggingface.co/VPTQ-community/Qwen2.5-72B-Instruct-v16-k65536-32768-woft) | +| Reproduced from the tech report | [HF 🤗](https://huggingface.co/collections/VPTQ-community/reproduced-vptq-tech-report-baseline-66fbf1dffe741cc9e93ecf04) | Results from the open source community for reference only, please use them responsibly. | +| Hessian and Inverse Hessian Matrix | [HF 🤗](https://huggingface.co/collections/VPTQ-community/hessian-and-invhessian-checkpoints-66fd249a104850d17b23fd8b) | Collected from RedPajama-Data-1T-Sample, following [Quip#](https://github.com/Cornell-RelaxML/quip-sharp/blob/main/quantize_llama/hessian_offline_llama.py) \ No newline at end of file diff --git a/docs/source/ko/_toctree.yml b/docs/source/ko/_toctree.yml index 7e9567769cca1a..54740610ee1148 100644 --- a/docs/source/ko/_toctree.yml +++ b/docs/source/ko/_toctree.yml @@ -151,6 +151,8 @@ title: AWQ - local: in_translation title: (번역중) AQLM + - local: in_translation + title: (번역중) VPTQ - local: in_translation title: (번역중) Quanto - local: in_translation @@ -173,6 +175,8 @@ title: (번역중) AWQ - local: in_translation title: (번역중) AQLM + - local: in_translation + title: (번역중) VPTQ - local: quantization/quanto title: Quanto - local: quantization/eetq diff --git a/docs/source/ko/llm_optims.md b/docs/source/ko/llm_optims.md index 656ed53584c226..99eabc19ce860a 100644 --- a/docs/source/ko/llm_optims.md +++ b/docs/source/ko/llm_optims.md @@ -375,7 +375,7 @@ with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable 양자화는 LLM 가중치를 더 낮은 정밀도로 저장하여 크기를 줄입니다. 이는 메모리 사용량을 줄이며 GPU 메모리에 제약이 있는 경우 추론을 위해 LLM을 로드하는 것을 더 용이하게 합니다. GPU가 충분하다면, 모델을 양자화할 필요는 없습니다. 추가적인 양자화 및 양자화 해제 단계로 인해 약간의 지연이 발생할 수 있기 때문입니다(AWQ 및 융합 AWQ 모듈 제외). > [!TIP] -> 다양한 양자화 라이브러리(자세한 내용은 [Quantization](./quantization) 가이드를 참조하십시오)가 있습니다. 여기에는 Quanto, AQLM, AWQ 및 AutoGPTQ가 포함됩니다. 사용 사례에 가장 잘 맞는 라이브러리를 사용해 보십시오. 또한 AutoGPTQ와 bitsandbytes를 비교하는 [Overview of natively supported quantization schemes in 🤗 Transformers](https://hf.co/blog/overview-quantization-transformers) 블로그 게시물을 읽어보는 것을 추천합니다. +> 다양한 양자화 라이브러리(자세한 내용은 [Quantization](./quantization) 가이드를 참조하십시오)가 있습니다. 여기에는 Quanto, AQLM, VPTQ, AWQ 및 AutoGPTQ가 포함됩니다. 사용 사례에 가장 잘 맞는 라이브러리를 사용해 보십시오. 또한 AutoGPTQ와 bitsandbytes를 비교하는 [Overview of natively supported quantization schemes in 🤗 Transformers](https://hf.co/blog/overview-quantization-transformers) 블로그 게시물을 읽어보는 것을 추천합니다. 아래의 모델 메모리 계산기를 사용하여 모델을 로드하는 데 필요한 메모리를 추정하고 비교해 보십시오. 
예를 들어 [Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1)를 로드하는 데 필요한 메모리를 추정해 보십시오. diff --git a/docs/source/ko/main_classes/quantization.md b/docs/source/ko/main_classes/quantization.md index b1d1730d28d00b..6f793f22107417 100644 --- a/docs/source/ko/main_classes/quantization.md +++ b/docs/source/ko/main_classes/quantization.md @@ -35,6 +35,10 @@ Transformers에서 지원되지 않는 양자화 기법들은 [`HfQuantizer`] [[autodoc]] AqlmConfig +## VptqConfig[[transformers.VptqConfig]] + +[[autodoc]] VptqConfig + ## AwqConfig[[transformers.AwqConfig]] [[autodoc]] AwqConfig diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 600d3d217fa8a9..681bf1a5d16a36 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1000,6 +1000,7 @@ "HqqConfig", "QuantoConfig", "TorchAoConfig", + "VptqConfig", ], } @@ -6017,6 +6018,7 @@ HqqConfig, QuantoConfig, TorchAoConfig, + VptqConfig, ) try: diff --git a/src/transformers/integrations/__init__.py b/src/transformers/integrations/__init__.py index 093e0af29844e4..32c828cd6e5b44 100755 --- a/src/transformers/integrations/__init__.py +++ b/src/transformers/integrations/__init__.py @@ -105,6 +105,7 @@ ], "peft": ["PeftAdapterMixin"], "quanto": ["replace_with_quanto_layers"], + "vptq": ["replace_with_vptq_linear"], } try: @@ -207,6 +208,7 @@ ) from .peft import PeftAdapterMixin from .quanto import replace_with_quanto_layers + from .vptq import replace_with_vptq_linear try: if not is_torch_available(): diff --git a/src/transformers/integrations/vptq.py b/src/transformers/integrations/vptq.py new file mode 100644 index 00000000000000..aa435517e81ebe --- /dev/null +++ b/src/transformers/integrations/vptq.py @@ -0,0 +1,101 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"VPTQ (Vector Post-Training Quantization) integration file" + +import torch.nn as nn +from accelerate import init_empty_weights +from vptq import VQuantLinear + + +def replace_with_vptq_linear( + model, + quantization_config=None, + modules_to_not_convert=None, + current_key_name=None, + has_been_replaced=False, +): + """ + Public method that recursively replaces the Linear layers of the given model with VPTQ quantized layers. + `accelerate` is needed to use this method. Returns the converted model and a boolean that indicates if the + conversion has been successfull or not. + + Args: + model (`torch.nn.Module`): + The model to convert, can be any `torch.nn.Module` instance. + quantization_config (`VptqConfig`): + The quantization config object that contains the quantization parameters. + modules_to_not_convert (`List[`str`]`, *optional*, defaults to `["lm_head"]`): + Names of the modules to not convert in `VQuantLinear`. In practice we keep the `lm_head` in full precision + for numerical stability reasons. + current_key_name (`list`, *optional*): + A list that contains the current key name. This is used for recursion and should not be passed by the user. 
+ has_been_replaced (`bool`, *optional*): + A boolean that indicates if the conversion has been successful or not. This is used for recursion and + should not be passed by the user. + """ + + modules_to_not_convert = ["lm_head"] if not modules_to_not_convert else modules_to_not_convert + + for name, module in model.named_children(): + if current_key_name is None: + current_key_name = [] + current_key_name.append(name) + layer_name = ".".join(current_key_name) + shared_layer_config = quantization_config.shared_layer_config + config_for_layers = quantization_config.config_for_layers + + if ( + isinstance(module, nn.Linear) + and layer_name not in modules_to_not_convert + and ((layer_name in config_for_layers) or (current_key_name[-1] in shared_layer_config)) + ): + layer_params = config_for_layers.get(layer_name, None) or shared_layer_config.get( + current_key_name[-1], None + ) + + with init_empty_weights(): + in_features = module.in_features + out_features = module.out_features + + model._modules[name] = VQuantLinear( + in_features, + out_features, + vector_lens=layer_params["vector_lens"], + num_centroids=layer_params["num_centroids"], + num_res_centroids=layer_params["num_res_centroids"], + group_num=layer_params["group_num"], + group_size=layer_params["group_size"], + outlier_size=layer_params["outlier_size"], + indices_as_float=layer_params["indices_as_float"], + enable_norm=layer_params["enable_norm"], + enable_perm=layer_params["enable_perm"], + is_indice_packed=True, + enable_proxy_error=False, + bias=module.bias is not None, + ) + has_been_replaced = True + + # Force requires grad to False to avoid unexpected errors + model._modules[name].requires_grad_(False) + if len(list(module.children())) > 0: + _, has_been_replaced = replace_with_vptq_linear( + module, + quantization_config=quantization_config, + modules_to_not_convert=modules_to_not_convert, + current_key_name=current_key_name, + has_been_replaced=has_been_replaced, + ) + # Remove the last key for recursion + current_key_name.pop(-1) + return model, has_been_replaced diff --git a/src/transformers/quantizers/auto.py b/src/transformers/quantizers/auto.py index 818072a0d91647..47b54cd27bcebe 100755 --- a/src/transformers/quantizers/auto.py +++ b/src/transformers/quantizers/auto.py @@ -29,6 +29,7 @@ QuantizationMethod, QuantoConfig, TorchAoConfig, + VptqConfig, ) from .quantizer_aqlm import AqlmHfQuantizer from .quantizer_awq import AwqQuantizer @@ -42,6 +43,7 @@ from .quantizer_hqq import HqqHfQuantizer from .quantizer_quanto import QuantoHfQuantizer from .quantizer_torchao import TorchAoHfQuantizer +from .quantizer_vptq import VptqHfQuantizer AUTO_QUANTIZER_MAPPING = { @@ -57,6 +59,7 @@ "fbgemm_fp8": FbgemmFp8HfQuantizer, "torchao": TorchAoHfQuantizer, "bitnet": BitNetHfQuantizer, + "vptq": VptqHfQuantizer, } AUTO_QUANTIZATION_CONFIG_MAPPING = { @@ -72,6 +75,7 @@ "fbgemm_fp8": FbgemmFp8Config, "torchao": TorchAoConfig, "bitnet": BitNetConfig, + "vptq": VptqConfig, } diff --git a/src/transformers/quantizers/quantizer_vptq.py b/src/transformers/quantizers/quantizer_vptq.py new file mode 100644 index 00000000000000..1672c3ebc5a7d3 --- /dev/null +++ b/src/transformers/quantizers/quantizer_vptq.py @@ -0,0 +1,98 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING, Optional + +from .base import HfQuantizer + + +if TYPE_CHECKING: + from ..modeling_utils import PreTrainedModel + +from ..utils import is_accelerate_available, is_torch_available, is_vptq_available, logging +from ..utils.quantization_config import QuantizationConfigMixin + + +if is_torch_available(): + import torch + +logger = logging.get_logger(__name__) + + +class VptqHfQuantizer(HfQuantizer): + """ + Quantizer of the VPTQ method. Enables the loading of prequantized models. + """ + + requires_calibration = True + required_packages = ["vptq"] + + def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs): + super().__init__(quantization_config, **kwargs) + self.quantization_config = quantization_config + + def validate_environment(self, *args, **kwargs): + if not is_accelerate_available(): + raise ImportError("Using `vptq` quantization requires Accelerate: `pip install accelerate`") + + if not is_vptq_available(): + raise ImportError("Using `vptq` quantization requires VPTQ>=0.0.4: `pip install -U vptq`") + + def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype": + if torch_dtype is None: + if torch.cuda.is_available(): + torch_dtype = torch.float16 + logger.info( + "CUDA available. Assuming VPTQ inference on GPU and loading the model in `torch.float16`. To overwrite it, set `torch_dtype` manually." + ) + else: + import vptq + + device_availability = getattr(vptq, "device_availability", lambda device: False) + if device_availability("cpu") is True: + raise RuntimeError("No GPU found. Please wait for the next release of VPTQ to use CPU inference") + torch_dtype = torch.float32 + logger.info("No GPU found. 
Assuming VPTQ inference on CPU and loading the model in `torch.float32`.") + return torch_dtype + + def _process_model_before_weight_loading( + self, + model: "PreTrainedModel", + **kwargs, + ): + """ + we don't have param like modules_to_not_convert to indicate which layers should not be quantized + because `quantization_config` include the layers that should be quantized + """ + from ..integrations import replace_with_vptq_linear + + modules_to_not_convert = kwargs.get("modules_to_not_convert", []) + ( + self.quantization_config.modules_to_not_convert or [] + ) + + replace_with_vptq_linear( + model, + quantization_config=self.quantization_config, + modules_to_not_convert=modules_to_not_convert, + ) + model.config.quantization_config = self.quantization_config + + def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs): + return model + + @property + def is_trainable(self, model: Optional["PreTrainedModel"] = None): + return False + + def is_serializable(self, safe_serialization=None): + return True diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 409f274d41eb17..5b0b9e7686e925 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -142,6 +142,7 @@ is_torchdynamo_available, is_torchvision_available, is_vision_available, + is_vptq_available, strtobool, ) @@ -1142,6 +1143,13 @@ def require_aqlm(test_case): return unittest.skipUnless(is_aqlm_available(), "test requires aqlm")(test_case) +def require_vptq(test_case): + """ + Decorator marking a test that requires vptq + """ + return unittest.skipUnless(is_vptq_available(), "test requires vptq")(test_case) + + def require_eetq(test_case): """ Decorator marking a test that requires eetq diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index 7fb647b253832e..2edfcdcd101c78 100755 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -233,6 +233,7 @@ is_training_run_on_sagemaker, is_uroman_available, is_vision_available, + is_vptq_available, requires_backends, torch_only_method, ) diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 92823a4ee016c3..cfc8b88fd81ed6 100755 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -93,11 +93,13 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[ GGUF_MIN_VERSION = "0.10.0" XLA_FSDPV2_MIN_VERSION = "2.2.0" HQQ_MIN_VERSION = "0.2.1" +VPTQ_MIN_VERSION = "0.0.4" _accelerate_available, _accelerate_version = _is_package_available("accelerate", return_version=True) _apex_available = _is_package_available("apex") _aqlm_available = _is_package_available("aqlm") +_vptq_available, _vptq_version = _is_package_available("vptq", return_version=True) _av_available = importlib.util.find_spec("av") is not None _bitsandbytes_available = _is_package_available("bitsandbytes") _eetq_available = _is_package_available("eetq") @@ -816,6 +818,10 @@ def is_aqlm_available(): return _aqlm_available +def is_vptq_available(min_version: str = VPTQ_MIN_VERSION): + return _vptq_available and version.parse(_vptq_version) >= version.parse(min_version) + + def is_av_available(): return _av_available diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py index 253cc4a0621080..44e47e4f6e65c2 100755 --- a/src/transformers/utils/quantization_config.py +++ b/src/transformers/utils/quantization_config.py @@ -39,6 
+39,7 @@ class QuantizationMethod(str, Enum): GPTQ = "gptq" AWQ = "awq" AQLM = "aqlm" + VPTQ = "vptq" QUANTO = "quanto" EETQ = "eetq" HQQ = "hqq" @@ -994,6 +995,102 @@ def post_init(self): self.linear_weights_not_to_quantize = [] +@dataclass +class VptqLayerConfig(QuantizationConfigMixin): + """ + This is used to explain vptq config params for each layer + Args: + enable_norm (`bool`, *optional*, defaults to `True`): to control if we have scale/bias for fp-weight + enable_perm (`bool`, *optional*, defaults to `True`): to perm input_channel or not + group_num (`int`, *optional*, defaults to `1`): how many single groups for vector-quantization + group_size (`int`, *optional*, defaults to `-1`): depends on out-features + indices_as_float (`bool`, *optional*, defaults to `False`): for Finetuning + is_indice_packed (`bool`, *optional*, defaults to `True`): should always be True + num_centroids (`list`, *optional*, defaults to `[-1, -1]`): centriod numbers of clusters + num_res_centroids (`list`, *optional*, defaults to `[-1, -1]`): ditto for residual + outlier_size (`int`, *optional*, defaults to `1`): outliers + vector_lens (`list`, *optional*, defaults to `[-1, -1]`): centroid vector length in quantization + """ + + def __init__( + self, + enable_norm: bool = True, + enable_perm: bool = True, + group_num: int = 1, + group_size: int = -1, + in_features: int = -1, + indices_as_float: bool = False, + is_indice_packed: bool = True, + num_centroids: tuple = [-1, -1], + num_res_centroids: tuple = [-1, -1], + out_features: int = -1, + outlier_size: int = 0, + vector_lens: tuple = [-1, -1], + **kwargs, + ): + self.enable_norm = enable_norm + self.enable_perm = enable_perm + self.group_num = group_num + self.group_size = group_size + self.in_features = in_features + self.indices_as_float = indices_as_float + self.is_indice_packed = is_indice_packed + self.num_centroids = num_centroids + self.num_res_centroids = num_res_centroids + self.out_features = out_features + self.outlier_size = outlier_size + self.vector_lens = vector_lens + self.post_init() + + def post_init(self): + r""" + Safety checker that arguments are correct + """ + if self.is_indice_packed is False: + raise ValueError("is_indice_packed should always be True") + + +@dataclass +class VptqConfig(QuantizationConfigMixin): + """ + This is a wrapper class about `vptq` parameters. + + Args: + enable_proxy_error (`bool`, *optional*, defaults to `False`): calculate proxy error for each layer + config_for_layers (`Dict`, *optional*, defaults to `{}`): quantization params for each layer + shared_layer_config (`Dict`, *optional*, defaults to `{}`): shared quantization params among layers + modules_to_not_convert (`list`, *optional*, default to `None`): + The list of modules to not quantize, useful for quantizing models that explicitly require to have + some modules left in their original precision (e.g. Whisper encoder, Llava encoder, Mixtral gate layers). + kwargs (`Dict[str, Any]`, *optional*): + Additional parameters from which to initialize the configuration object. 
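+
+    Example (a minimal, illustrative configuration mirroring the values used in the VPTQ integration tests;
+    real prequantized checkpoints ship their own `quantization_config`):
+
+    ```python
+    from transformers import VptqConfig
+
+    layer_params = {
+        "enable_norm": True,
+        "enable_perm": True,
+        "group_num": 1,
+        "group_size": 128,
+        "indices_as_float": False,
+        "num_centroids": [-1, 128],
+        "num_res_centroids": [-1, 128],
+        "outlier_size": 0,
+        "vector_lens": [-1, 12],
+    }
+    quantization_config = VptqConfig(
+        config_for_layers={"model.decoder.project_in": layer_params},
+        shared_layer_config={"q_proj": layer_params},
+    )
+    ```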
+ """ + + def __init__( + self, + enable_proxy_error: bool = False, + config_for_layers: Dict[str, Any] = {}, + shared_layer_config: Dict[str, Any] = {}, + modules_to_not_convert: Optional[List] = None, + **kwargs, + ): + self.quant_method = QuantizationMethod.VPTQ + self.enable_proxy_error = enable_proxy_error + self.config_for_layers: Dict[str, Any] = config_for_layers + self.shared_layer_config: Dict[str, Any] = shared_layer_config + self.modules_to_not_convert = modules_to_not_convert + self.post_init() + + def post_init(self): + r""" + Safety checker that arguments are correct + """ + for layer_name, layer_param in self.config_for_layers.items(): + VptqLayerConfig(**layer_param) + if self.enable_proxy_error is True: + raise ValueError("enable_proxy_error should always be False until we support training") + + @dataclass class QuantoConfig(QuantizationConfigMixin): """ diff --git a/tests/quantization/vptq_integration/__init__.py b/tests/quantization/vptq_integration/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/tests/quantization/vptq_integration/test_vptq.py b/tests/quantization/vptq_integration/test_vptq.py new file mode 100644 index 00000000000000..faa9a5879d1dcc --- /dev/null +++ b/tests/quantization/vptq_integration/test_vptq.py @@ -0,0 +1,194 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import tempfile +import unittest + +from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, VptqConfig +from transformers.testing_utils import ( + require_accelerate, + require_torch_gpu, + require_torch_multi_gpu, + require_vptq, + slow, + torch_device, +) +from transformers.utils import is_accelerate_available, is_torch_available + + +if is_torch_available(): + import torch + +if is_accelerate_available(): + from accelerate import init_empty_weights + + +class VptqConfigTest(unittest.TestCase): + def test_to_dict(self): + """ + Makes sure the config format is properly set + """ + quantization_config = VptqConfig() + vptq_orig_config = quantization_config.to_dict() + + self.assertEqual(quantization_config.quant_config, vptq_orig_config["quant_config"]) + + +@slow +@require_torch_gpu +@require_vptq +@require_accelerate +class VptqTest(unittest.TestCase): + model_name = "VPTQ-community/Meta-Llama-3.1-8B-Instruct-v12-k65536-4096-woft" + + input_text = "Hello my name is" + max_new_tokens = 32 + + EXPECTED_OUTPUT = "Hello my name is Sarah and I am a 25 year old woman from the United States. 
I am a college graduate and I am currently working as a marketing specialist for a small" + + device_map = "cuda" + + # called only once for all test in this class + @classmethod + def setUpClass(cls): + """ + Setup quantized model + """ + cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name) + cls.quantized_model = AutoModelForCausalLM.from_pretrained( + cls.model_name, + device_map=cls.device_map, + ) + + def tearDown(self): + gc.collect() + torch.cuda.empty_cache() + gc.collect() + + def test_quantized_model(self): + """ + Simple test that checks if the quantized model is working properly + """ + input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device) + + output = self.quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens, do_sample=False) + self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT) + + def test_raise_if_non_quantized(self): + model_id = "facebook/opt-125m" + quantization_config = VptqConfig() + + with self.assertRaises(ValueError): + _ = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quantization_config) + + def test_save_pretrained(self): + """ + Simple test that checks if the quantized model is working properly after being saved and loaded + """ + with tempfile.TemporaryDirectory() as tmpdirname: + self.quantized_model.save_pretrained(tmpdirname) + model = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map=self.device_map) + + input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device) + + output = model.generate(**input_ids, max_new_tokens=self.max_new_tokens, do_sample=False) + self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT) + + @require_torch_multi_gpu + def test_quantized_model_multi_gpu(self): + """ + Simple test that checks if the quantized model is working properly with multiple GPUs + """ + input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device) + + quantized_model = AutoModelForCausalLM.from_pretrained(self.model_name, device_map="auto") + + self.assertTrue(set(quantized_model.hf_device_map.values()) == {0, 1}) + + output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens, do_sample=False) + + self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT) + + def test_quantized_model_conversion(self): + """ + Simple test that checks if the quantized model has been converted properly + """ + from vptq import VQuantLinear + + from transformers.integrations import replace_with_vptq_linear + + model_id = "facebook/opt-350m" + config = AutoConfig.from_pretrained(model_id, revision="cb32f77e905cccbca1d970436fb0f5e6b58ee3c5") + modules_to_not_convert = ["lm_head"] + names = [ + "q_proj", + "k_proj", + "v_proj", + "out_proj", + "fc1", + "fc2", + ] + value = { + "enable_norm": True, + "enable_perm": True, + "group_num": 1, + "group_size": 128, + "indices_as_float": False, + "num_centroids": [-1, 128], + "num_res_centroids": [-1, 128], + "outlier_size": 0, + "vector_lens": [-1, 12], + } + shared_layer_config = {} + for name in names: + shared_layer_config[name] = value + for i in range(24): + modules_to_not_convert.append("model.decoder.layers.{layer_idx}.fc1".format(layer_idx=i)) + layer_configs = {} + layer_configs["model.decoder.project_out"] = value + layer_configs["model.decoder.project_in"] = value + quantization_config = VptqConfig(config_for_layers=layer_configs, 
shared_layer_config=shared_layer_config) + + with init_empty_weights(): + model = AutoModelForCausalLM.from_config(config) + + nb_linears = 0 + for module in model.modules(): + if isinstance(module, torch.nn.Linear): + nb_linears += 1 + + model, _ = replace_with_vptq_linear(model, quantization_config=quantization_config) + nb_vptq_linear = 0 + for module in model.modules(): + if isinstance(module, VQuantLinear): + nb_vptq_linear += 1 + + self.assertEqual(nb_linears - 1, nb_vptq_linear) + + # Try with `linear_weights_not_to_quantize` + with init_empty_weights(): + model = AutoModelForCausalLM.from_config(config) + quantization_config = VptqConfig(config_for_layers=layer_configs, shared_layer_config=shared_layer_config) + model, _ = replace_with_vptq_linear( + model, quantization_config=quantization_config, modules_to_not_convert=modules_to_not_convert + ) + nb_vptq_linear = 0 + for module in model.modules(): + if isinstance(module, VQuantLinear): + nb_vptq_linear += 1 + # 25 comes from 24 decoder.layers.{layer_idx}.fc1 + # and the last lm_head + self.assertEqual(nb_linears - 25, nb_vptq_linear) From b5a557e5fe2d015bd36214a95878370eaed51571 Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Fri, 20 Dec 2024 10:18:15 +0100 Subject: [PATCH 15/23] Reduce CircleCI usage (#35355) * reduce 1 * reduce 1 --------- Co-authored-by: ydshieh --- .circleci/create_circleci_config.py | 23 +++++++++-------------- 1 file changed, 9 insertions(+), 14 deletions(-) diff --git a/.circleci/create_circleci_config.py b/.circleci/create_circleci_config.py index be8952903e2ce2..71c75dac2ff053 100644 --- a/.circleci/create_circleci_config.py +++ b/.circleci/create_circleci_config.py @@ -55,6 +55,7 @@ def to_dict(self): return { "docker": copy.deepcopy(DEFAULT_DOCKER_IMAGE), + "resource_class": "small", "steps": steps, } @@ -67,9 +68,9 @@ class CircleCIJob: install_steps: List[str] = None marker: Optional[str] = None parallelism: Optional[int] = 0 - pytest_num_workers: int = 12 + pytest_num_workers: int = 8 pytest_options: Dict[str, Any] = None - resource_class: Optional[str] = "2xlarge" + resource_class: Optional[str] = "xlarge" tests_to_run: Optional[List[str]] = None num_test_files_per_worker: Optional[int] = 10 # This should be only used for doctest job! 
@@ -198,7 +199,6 @@ def job_name(self): docker_image=[{"image": "huggingface/transformers-torch-light"}], marker="not generate", parallelism=6, - pytest_num_workers=8 ) generate_job = CircleCIJob( @@ -206,28 +206,24 @@ def job_name(self): docker_image=[{"image": "huggingface/transformers-torch-light"}], marker="generate", parallelism=6, - pytest_num_workers=8 ) tokenization_job = CircleCIJob( "tokenization", docker_image=[{"image": "huggingface/transformers-torch-light"}], parallelism=8, - pytest_num_workers=16 ) processor_job = CircleCIJob( "processors", docker_image=[{"image": "huggingface/transformers-torch-light"}], parallelism=8, - pytest_num_workers=6 ) tf_job = CircleCIJob( "tf", docker_image=[{"image":"huggingface/transformers-tf-light"}], parallelism=6, - pytest_num_workers=16, ) @@ -235,7 +231,8 @@ def job_name(self): "flax", docker_image=[{"image":"huggingface/transformers-jax-light"}], parallelism=6, - pytest_num_workers=16 + pytest_num_workers=16, + resource_class="2xlarge", ) @@ -244,7 +241,7 @@ def job_name(self): additional_env={"RUN_PIPELINE_TESTS": True}, docker_image=[{"image":"huggingface/transformers-torch-light"}], marker="is_pipeline_test", - parallelism=4 + parallelism=4, ) @@ -253,7 +250,7 @@ def job_name(self): additional_env={"RUN_PIPELINE_TESTS": True}, docker_image=[{"image":"huggingface/transformers-tf-light"}], marker="is_pipeline_test", - parallelism=4 + parallelism=4, ) @@ -270,7 +267,6 @@ def job_name(self): docker_image=[{"image":"huggingface/transformers-examples-torch"}], # TODO @ArthurZucker remove this once docker is easier to build install_steps=["uv venv && uv pip install . && uv pip install -r examples/pytorch/_tests_requirements.txt"], - pytest_num_workers=8, ) @@ -278,7 +274,6 @@ def job_name(self): "examples_tensorflow", additional_env={"OMP_NUM_THREADS": 8}, docker_image=[{"image":"huggingface/transformers-examples-tf"}], - pytest_num_workers=16, ) @@ -293,6 +288,7 @@ def job_name(self): ], marker="is_staging_test", pytest_num_workers=2, + resource_class="medium", ) @@ -305,13 +301,13 @@ def job_name(self): ], pytest_options={"k onnx": None}, pytest_num_workers=1, + resource_class="small", ) exotic_models_job = CircleCIJob( "exotic_models", docker_image=[{"image":"huggingface/transformers-exotic-models"}], - pytest_num_workers=12, parallelism=4, pytest_options={"durations": 100}, ) @@ -330,7 +326,6 @@ def job_name(self): docker_image=[{"image": "huggingface/transformers-torch-light"}], marker="not generate", parallelism=6, - pytest_num_workers=8, ) From eafbb0eca7171436138ad0cbbd1c7f860819510e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?= Date: Fri, 20 Dec 2024 12:08:12 +0100 Subject: [PATCH 16/23] Implement AsyncTextIteratorStreamer for asynchronous streaming (#34931) * Add AsyncTextIteratorStreamer class * export AsyncTextIteratorStreamer * export AsyncTextIteratorStreamer * improve docs * missing import * missing import * doc example fix * doc example output fix * add pytest-asyncio * first attempt at tests * missing import * add pytest-asyncio * fallback to wait_for and raise TimeoutError on timeout * check for TimeoutError * autodoc * reorder imports * fix style --------- Co-authored-by: Arthur Zucker Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- docs/source/en/internal/generation_utils.md | 2 + setup.py | 2 + src/transformers/__init__.py | 10 ++- src/transformers/dependency_versions_table.py | 1 + src/transformers/generation/__init__.py | 4 +- src/transformers/generation/streamers.py 
| 89 +++++++++++++++++++ tests/generation/test_streamers.py | 50 ++++++++++- 7 files changed, 154 insertions(+), 4 deletions(-) diff --git a/docs/source/en/internal/generation_utils.md b/docs/source/en/internal/generation_utils.md index a54ac432006a84..d8931342ee45f8 100644 --- a/docs/source/en/internal/generation_utils.md +++ b/docs/source/en/internal/generation_utils.md @@ -352,6 +352,8 @@ A [`Constraint`] can be used to force the generation to include specific tokens [[autodoc]] TextIteratorStreamer +[[autodoc]] AsyncTextIteratorStreamer + ## Caches [[autodoc]] Cache diff --git a/setup.py b/setup.py index c2c0048d6913ec..9e678db9978b14 100644 --- a/setup.py +++ b/setup.py @@ -148,6 +148,7 @@ "pyyaml>=5.1", "pydantic", "pytest>=7.2.0,<8.0.0", + "pytest-asyncio", "pytest-timeout", "pytest-xdist", "python>=3.9.0", @@ -319,6 +320,7 @@ def run(self): extras["testing"] = ( deps_list( "pytest", + "pytest-asyncio", "pytest-rich", "pytest-xdist", "timeout-decorator", diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 681bf1a5d16a36..5510ac6c8ad512 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -122,6 +122,7 @@ "feature_extraction_utils": ["BatchFeature", "FeatureExtractionMixin"], "file_utils": [], "generation": [ + "AsyncTextIteratorStreamer", "CompileConfig", "GenerationConfig", "TextIteratorStreamer", @@ -5055,7 +5056,14 @@ from .feature_extraction_utils import BatchFeature, FeatureExtractionMixin # Generation - from .generation import CompileConfig, GenerationConfig, TextIteratorStreamer, TextStreamer, WatermarkingConfig + from .generation import ( + AsyncTextIteratorStreamer, + CompileConfig, + GenerationConfig, + TextIteratorStreamer, + TextStreamer, + WatermarkingConfig, + ) from .hf_argparser import HfArgumentParser # Integrations diff --git a/src/transformers/dependency_versions_table.py b/src/transformers/dependency_versions_table.py index 85345cc8e5889d..c370c7a5d7c18c 100644 --- a/src/transformers/dependency_versions_table.py +++ b/src/transformers/dependency_versions_table.py @@ -54,6 +54,7 @@ "pyyaml": "pyyaml>=5.1", "pydantic": "pydantic", "pytest": "pytest>=7.2.0,<8.0.0", + "pytest-asyncio": "pytest-asyncio", "pytest-timeout": "pytest-timeout", "pytest-xdist": "pytest-xdist", "python": "python>=3.9.0", diff --git a/src/transformers/generation/__init__.py b/src/transformers/generation/__init__.py index 59d970db15416f..d3eb10c1e6b355 100644 --- a/src/transformers/generation/__init__.py +++ b/src/transformers/generation/__init__.py @@ -26,7 +26,7 @@ "SynthIDTextWatermarkingConfig", "WatermarkingConfig", ], - "streamers": ["TextIteratorStreamer", "TextStreamer"], + "streamers": ["AsyncTextIteratorStreamer", "TextIteratorStreamer", "TextStreamer"], } try: @@ -199,7 +199,7 @@ SynthIDTextWatermarkingConfig, WatermarkingConfig, ) - from .streamers import TextIteratorStreamer, TextStreamer + from .streamers import AsyncTextIteratorStreamer, TextIteratorStreamer, TextStreamer try: if not is_torch_available(): diff --git a/src/transformers/generation/streamers.py b/src/transformers/generation/streamers.py index c75b43466af7a8..c78e259db38be8 100644 --- a/src/transformers/generation/streamers.py +++ b/src/transformers/generation/streamers.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import asyncio from queue import Queue from typing import TYPE_CHECKING, Optional @@ -225,3 +226,91 @@ def __next__(self): raise StopIteration() else: return value + + +class AsyncTextIteratorStreamer(TextStreamer): + """ + Streamer that stores print-ready text in a queue, to be used by a downstream application as an async iterator. + This is useful for applications that benefit from acessing the generated text asynchronously (e.g. in an + interactive Gradio demo). + + + + The API for the streamer classes is still under development and may change in the future. + + + + Parameters: + tokenizer (`AutoTokenizer`): + The tokenized used to decode the tokens. + skip_prompt (`bool`, *optional*, defaults to `False`): + Whether to skip the prompt to `.generate()` or not. Useful e.g. for chatbots. + timeout (`float`, *optional*): + The timeout for the text queue. If `None`, the queue will block indefinitely. Useful to handle exceptions + in `.generate()`, when it is called in a separate thread. + decode_kwargs (`dict`, *optional*): + Additional keyword arguments to pass to the tokenizer's `decode` method. + + Raises: + TimeoutError: If token generation time exceeds timeout value. + + Examples: + + ```python + >>> from transformers import AutoModelForCausalLM, AutoTokenizer, AsyncTextIteratorStreamer + >>> from threading import Thread + >>> import asyncio + + >>> tok = AutoTokenizer.from_pretrained("openai-community/gpt2") + >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2") + >>> inputs = tok(["An increasing sequence: one,"], return_tensors="pt") + + >>> # Run the generation in a separate thread, so that we can fetch the generated text in a non-blocking way. + >>> async def main(): + ... # Important: AsyncTextIteratorStreamer must be initialized inside a coroutine! + ... streamer = AsyncTextIteratorStreamer(tok) + ... generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=20) + ... thread = Thread(target=model.generate, kwargs=generation_kwargs) + ... thread.start() + ... generated_text = "" + ... async for new_text in streamer: + ... generated_text += new_text + >>> print(generated_text) + >>> asyncio.run(main()) + An increasing sequence: one, two, three, four, five, six, seven, eight, nine, ten, eleven, + ``` + """ + + def __init__( + self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, timeout: float | None = None, **decode_kwargs + ): + super().__init__(tokenizer, skip_prompt, **decode_kwargs) + self.text_queue = asyncio.Queue() + self.stop_signal = None + self.timeout = timeout + self.loop = asyncio.get_running_loop() + self.has_asyncio_timeout = hasattr(asyncio, "timeout") + + def on_finalized_text(self, text: str, stream_end: bool = False): + """Put the new text in the queue. 
If the stream is ending, also put a stop signal in the queue.""" + self.loop.call_soon_threadsafe(self.text_queue.put_nowait, text) + if stream_end: + self.loop.call_soon_threadsafe(self.text_queue.put_nowait, self.stop_signal) + + def __aiter__(self): + return self + + async def __anext__(self): + try: + if self.has_asyncio_timeout: + async with asyncio.timeout(self.timeout): + value = await self.text_queue.get() + else: + value = await asyncio.wait_for(self.text_queue.get(), timeout=self.timeout) + except asyncio.TimeoutError: + raise TimeoutError() + else: + if value == self.stop_signal: + raise StopAsyncIteration() + else: + return value diff --git a/tests/generation/test_streamers.py b/tests/generation/test_streamers.py index c82a5e99e0ded0..be8c37334d02fc 100644 --- a/tests/generation/test_streamers.py +++ b/tests/generation/test_streamers.py @@ -17,7 +17,15 @@ from queue import Empty from threading import Thread -from transformers import AutoTokenizer, TextIteratorStreamer, TextStreamer, is_torch_available +import pytest + +from transformers import ( + AsyncTextIteratorStreamer, + AutoTokenizer, + TextIteratorStreamer, + TextStreamer, + is_torch_available, +) from transformers.testing_utils import CaptureStdout, require_torch, torch_device from ..test_modeling_common import ids_tensor @@ -120,3 +128,43 @@ def test_iterator_streamer_timeout(self): streamer_text = "" for new_text in streamer: streamer_text += new_text + + +@require_torch +@pytest.mark.asyncio(loop_scope="class") +class AsyncStreamerTester(unittest.IsolatedAsyncioTestCase): + async def test_async_iterator_streamer_matches_non_streaming(self): + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2") + model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device) + model.config.eos_token_id = -1 + + input_ids = ids_tensor((1, 5), vocab_size=model.config.vocab_size).to(torch_device) + greedy_ids = model.generate(input_ids, max_new_tokens=10, do_sample=False) + greedy_text = tokenizer.decode(greedy_ids[0]) + + streamer = AsyncTextIteratorStreamer(tokenizer) + generation_kwargs = {"input_ids": input_ids, "max_new_tokens": 10, "do_sample": False, "streamer": streamer} + thread = Thread(target=model.generate, kwargs=generation_kwargs) + thread.start() + streamer_text = "" + async for new_text in streamer: + streamer_text += new_text + + self.assertEqual(streamer_text, greedy_text) + + async def test_async_iterator_streamer_timeout(self): + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2") + model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device) + model.config.eos_token_id = -1 + + input_ids = ids_tensor((1, 5), vocab_size=model.config.vocab_size).to(torch_device) + streamer = AsyncTextIteratorStreamer(tokenizer, timeout=0.001) + generation_kwargs = {"input_ids": input_ids, "max_new_tokens": 10, "do_sample": False, "streamer": streamer} + thread = Thread(target=model.generate, kwargs=generation_kwargs) + thread.start() + + # The streamer will timeout after 0.001 seconds, so TimeoutError will be raised + with self.assertRaises(TimeoutError): + streamer_text = "" + async for new_text in streamer: + streamer_text += new_text From 0d51d6590536e474f253a2060873391616cd015e Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Fri, 20 Dec 2024 12:09:34 +0100 Subject: [PATCH 17/23] Cleaner attention interfaces (#35342) * cleaner attention interfaces * correctly set the _attn_implementation when 
adding other functions to it * update * Update modeling_utils.py * CIs --- .../integrations/flash_attention.py | 21 ++++++++++++++----- .../integrations/flex_attention.py | 8 +++++-- .../integrations/sdpa_attention.py | 4 ++++ src/transformers/modeling_utils.py | 9 ++++---- 4 files changed, 30 insertions(+), 12 deletions(-) diff --git a/src/transformers/integrations/flash_attention.py b/src/transformers/integrations/flash_attention.py index 1be223f8b079ba..b8407bc29c6a8a 100644 --- a/src/transformers/integrations/flash_attention.py +++ b/src/transformers/integrations/flash_attention.py @@ -29,23 +29,34 @@ def flash_attention_forward( key = key.transpose(1, 2) value = value.transpose(1, 2) + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (usually our RMSNorm modules handle it correctly) + target_dtype = None if query.dtype == torch.float32: - query = query.to(torch.float16) - key = key.to(torch.float16) - value = value.to(torch.float16) + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(module.config, "_pre_quantization_dtype"): + target_dtype = module.config._pre_quantization_dtype + else: + target_dtype = next(layer for layer in module.modules() if isinstance(layer, torch.nn.Linear)).weight.dtype attn_output = _flash_attention_forward( query, key, value, attention_mask, - seq_len, - module.is_causal, + query_length=seq_len, + is_causal=module.is_causal, dropout=dropout, softmax_scale=scaling, sliding_window=sliding_window, softcap=softcap, use_top_left_mask=_use_top_left_mask, + target_dtype=target_dtype, **kwargs, ) diff --git a/src/transformers/integrations/flex_attention.py b/src/transformers/integrations/flex_attention.py index eacfb2b568b55b..66ffc5638838cb 100644 --- a/src/transformers/integrations/flex_attention.py +++ b/src/transformers/integrations/flex_attention.py @@ -2,10 +2,10 @@ import torch -from ..utils import is_torch_greater_or_equal +from ..utils import is_torch_flex_attn_available -if is_torch_greater_or_equal("2.5"): +if is_torch_flex_attn_available(): from torch.nn.attention.flex_attention import flex_attention @@ -37,8 +37,12 @@ def causal_mod(score, b, h, q_idx, kv_idx): score_mod=causal_mod, enable_gqa=True, scale=scaling, + # Last time checked on PyTorch == 2.5.1: Flex Attention always computes the lse regardless. + # For simplification, we thus always return it as no additional computations are introduced. 
return_lse=True, ) + # lse is returned in float32 + attention_weights = attention_weights.to(value.dtype) attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, attention_weights diff --git a/src/transformers/integrations/sdpa_attention.py b/src/transformers/integrations/sdpa_attention.py index 265260c9b79e4c..38701690bf7c2a 100644 --- a/src/transformers/integrations/sdpa_attention.py +++ b/src/transformers/integrations/sdpa_attention.py @@ -34,10 +34,14 @@ def sdpa_attention_forward( if attention_mask is not None: causal_mask = causal_mask[:, :, :, : key.shape[-2]] + # SDPA with memory-efficient backend is bugged with non-contiguous inputs and custom attn_mask for some torch versions + # Reference: https://github.com/pytorch/pytorch/issues/112577. query = query.contiguous() key = key.contiguous() value = value.contiguous() + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. if is_causal is None: is_causal = causal_mask is None and query.shape[2] > 1 diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 9dcd6d758ecbe7..49d086c76e8683 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1474,11 +1474,8 @@ def _autoset_attn_implementation( ) if not isinstance(config._attn_implementation, dict) and config._attn_implementation not in [ - "eager", - "sdpa", - "flash_attention_2", - "flex_attention", - ]: + "eager" + ] + list(ALL_ATTENTION_FUNCTIONS.keys()): message = f'Specified `attn_implementation="{config._attn_implementation}"` is not supported. The only possible arguments are `attn_implementation="eager"` (manual attention implementation)' if cls._supports_flash_attn_2: message += ', `"attn_implementation=flash_attention_2"` (implementation using flash attention 2)' @@ -1540,6 +1537,8 @@ def _autoset_attn_implementation( "Using the `SDPA` attention implementation on multi-gpu setup with ROCM may lead to performance issues due to the FA backend. Disabling it to use alternative backends." 
) torch.backends.cuda.enable_flash_sdp(False) + elif requested_attn_implementation in list(ALL_ATTENTION_FUNCTIONS.keys()): + config._attn_implementation = requested_attn_implementation elif isinstance(requested_attn_implementation, dict): config._attn_implementation = None else: From c3a43594b7ae870694f38f6d12074bd498c5477a Mon Sep 17 00:00:00 2001 From: Jacky Lee <39754370+jla524@users.noreply.github.com> Date: Fri, 20 Dec 2024 03:40:38 -0800 Subject: [PATCH 18/23] Add Tensor Parallel support for Qwen2VL (#35050) feat: add parallel support for qwen2vl --- .../models/qwen2_vl/configuration_qwen2_vl.py | 10 ++++++++++ .../models/qwen2_vl/modeling_qwen2_vl.py | 18 +++++++++--------- 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/src/transformers/models/qwen2_vl/configuration_qwen2_vl.py b/src/transformers/models/qwen2_vl/configuration_qwen2_vl.py index 55042327de4ec3..ef98ae5e3f508f 100644 --- a/src/transformers/models/qwen2_vl/configuration_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/configuration_qwen2_vl.py @@ -163,6 +163,16 @@ class Qwen2VLConfig(PretrainedConfig): model_type = "qwen2_vl" sub_configs = {"vision_config": Qwen2VLVisionConfig} keys_to_ignore_at_inference = ["past_key_values"] + # Default tensor parallel plan for base model `Qwen2VL` + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } def __init__( self, diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index 10c9b1638548ce..566141d3f75c27 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -547,9 +547,9 @@ def forward( key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) cos, sin = position_embeddings query_states, key_states = apply_multimodal_rotary_pos_emb( @@ -631,9 +631,9 @@ def forward( key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) # Because the input can be padded, the absolute sequence length depends on the max position id. 
cos, sin = position_embeddings @@ -750,9 +750,9 @@ def forward( key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) cos, sin = position_embeddings query_states, key_states = apply_multimodal_rotary_pos_emb( From 4567ee80572f51859f1454db687cacdf2ec12b13 Mon Sep 17 00:00:00 2001 From: Qizhi Chen Date: Fri, 20 Dec 2024 19:42:40 +0800 Subject: [PATCH 19/23] fix zoedepth initialization error under deepspeed zero3 (#35011) fix zoe bug in deepspeed zero3 --- src/transformers/models/zoedepth/modeling_zoedepth.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/zoedepth/modeling_zoedepth.py b/src/transformers/models/zoedepth/modeling_zoedepth.py index 5cbbdcdc04b756..2f4a42d2818005 100644 --- a/src/transformers/models/zoedepth/modeling_zoedepth.py +++ b/src/transformers/models/zoedepth/modeling_zoedepth.py @@ -417,7 +417,7 @@ def __init__(self, n_classes=256, act=torch.softmax): self.k = n_classes self.act = act self.register_buffer("k_idx", torch.arange(0, n_classes).view(1, -1, 1, 1), persistent=False) - self.register_buffer("k_minus_1", torch.Tensor([self.k - 1]).view(1, -1, 1, 1), persistent=False) + self.register_buffer("k_minus_1", torch.tensor([self.k - 1]).view(1, -1, 1, 1), persistent=False) def forward(self, probabilities, temperature=1.0, eps=1e-4): """Compute the log binomial distribution for probabilities. 
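The ZoeDepth fix above replaces the legacy `torch.Tensor` constructor with the `torch.tensor` factory function when registering the `k_minus_1` buffer. As a minimal illustrative sketch (not part of the diff): the factory function infers the dtype from its data and accepts explicit `dtype`/`device` arguments, while the legacy constructor always yields a default-dtype float tensor and rejects those keywords, which is one reason factory functions tend to compose more predictably with patched initialization contexts such as DeepSpeed ZeRO-3's `zero.Init`.

```python
# Illustrative only -- not part of the patch. Shows the behavioural gap between
# the legacy `torch.Tensor` constructor and the `torch.tensor` factory function.
import torch

k = 256

legacy = torch.Tensor([k - 1])    # legacy constructor: always the default float dtype
factory = torch.tensor([k - 1])   # factory function: dtype inferred from the data (int64 here)
explicit = torch.tensor([k - 1], dtype=torch.float32, device="cpu")  # dtype/device can be pinned

print(legacy.dtype, factory.dtype, explicit.dtype)  # torch.float32 torch.int64 torch.float32

# The legacy constructor accepts no dtype/device keyword arguments at all.
try:
    torch.Tensor([k - 1], dtype=torch.float32)
except TypeError as err:
    print("legacy constructor rejected dtype:", err)
```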
From 05de764e9ccaadd8baf4562f607777f821b48dae Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Fri, 20 Dec 2024 14:36:31 +0100 Subject: [PATCH 20/23] Aurevoir PyTorch 1 (#35358) * fix * fix * fix --------- Co-authored-by: ydshieh --- .../workflows/self-nightly-past-ci-caller.yml | 33 ------------------- README.md | 2 +- i18n/README_ar.md | 2 +- i18n/README_de.md | 2 +- i18n/README_es.md | 2 +- i18n/README_fr.md | 2 +- i18n/README_hd.md | 2 +- i18n/README_ja.md | 2 +- i18n/README_ko.md | 2 +- i18n/README_pt-br.md | 2 +- i18n/README_ru.md | 2 +- i18n/README_te.md | 2 +- i18n/README_ur.md | 2 +- i18n/README_vi.md | 2 +- i18n/README_zh-hans.md | 2 +- i18n/README_zh-hant.md | 2 +- .../convert_pytorch_checkpoint_to_tf2.py | 3 +- .../modeling_flax_pytorch_utils.py | 8 ++--- src/transformers/modeling_tf_pytorch_utils.py | 4 +-- src/transformers/modeling_utils.py | 5 ++- .../models/falcon/modeling_falcon.py | 9 ----- .../models/gpt_neo/modeling_gpt_neo.py | 4 --- .../models/phimoe/modeling_phimoe.py | 4 --- .../models/superpoint/modeling_superpoint.py | 3 +- .../models/tapas/modeling_tapas.py | 7 ---- .../models/wav2vec2/modeling_wav2vec2.py | 3 +- src/transformers/pytorch_utils.py | 3 -- src/transformers/trainer.py | 5 ++- src/transformers/trainer_pt_utils.py | 7 +--- src/transformers/training_args.py | 18 ++-------- src/transformers/utils/fx.py | 8 ++--- tests/models/aria/test_modeling_aria.py | 3 +- .../test_modeling_falcon_mamba.py | 6 ---- .../gpt_bigcode/test_modeling_gpt_bigcode.py | 7 ---- tests/models/gptj/test_modeling_gptj.py | 9 ----- tests/models/idefics/test_modeling_idefics.py | 6 ---- .../models/idefics2/test_modeling_idefics2.py | 2 -- .../models/idefics3/test_modeling_idefics3.py | 2 -- tests/models/llava/test_modeling_llava.py | 3 +- .../llava_next/test_modeling_llava_next.py | 3 +- .../test_modeling_llava_next_video.py | 2 -- .../test_modeling_llava_onevision.py | 2 -- tests/models/mamba/test_modeling_mamba.py | 6 ---- tests/models/mamba2/test_modeling_mamba2.py | 6 ---- .../paligemma/test_modeling_paligemma.py | 3 +- tests/models/pixtral/test_modeling_pixtral.py | 3 +- .../qwen2_audio/test_modeling_qwen2_audio.py | 2 -- .../models/qwen2_vl/test_modeling_qwen2_vl.py | 2 -- tests/models/rwkv/test_modeling_rwkv.py | 9 ----- tests/models/tapas/test_modeling_tapas.py | 9 ----- tests/models/tapas/test_tokenization_tapas.py | 9 +---- .../models/vipllava/test_modeling_vipllava.py | 2 -- ...test_pipelines_table_question_answering.py | 15 --------- 53 files changed, 37 insertions(+), 228 deletions(-) diff --git a/.github/workflows/self-nightly-past-ci-caller.yml b/.github/workflows/self-nightly-past-ci-caller.yml index 142399a6366ce6..46d811d4a43394 100644 --- a/.github/workflows/self-nightly-past-ci-caller.yml +++ b/.github/workflows/self-nightly-past-ci-caller.yml @@ -21,39 +21,6 @@ jobs: echo "$(python3 -c 'print(int(${{ github.run_number }}) % 10)')" echo "run_number=$(python3 -c 'print(int(${{ github.run_number }}) % 10)')" >> $GITHUB_OUTPUT - run_past_ci_pytorch_1-13: - name: PyTorch 1.13 - needs: get_number - if: needs.get_number.outputs.run_number == 0 && (cancelled() != true) && ((github.event_name == 'schedule') || ((github.event_name == 'push') && startsWith(github.ref_name, 'run_past_ci'))) - uses: ./.github/workflows/self-past-caller.yml - with: - framework: pytorch - version: "1.13" - sha: ${{ github.sha }} - secrets: inherit - - run_past_ci_pytorch_1-12: - name: PyTorch 1.12 - needs: get_number - if: 
needs.get_number.outputs.run_number == 1 && (cancelled() != true) && ((github.event_name == 'schedule') || ((github.event_name == 'push') && startsWith(github.ref_name, 'run_past_ci'))) - uses: ./.github/workflows/self-past-caller.yml - with: - framework: pytorch - version: "1.12" - sha: ${{ github.sha }} - secrets: inherit - - run_past_ci_pytorch_1-11: - name: PyTorch 1.11 - needs: get_number - if: needs.get_number.outputs.run_number == 2 && (cancelled() != true) && ((github.event_name == 'schedule') || ((github.event_name == 'push') && startsWith(github.ref_name, 'run_past_ci'))) - uses: ./.github/workflows/self-past-caller.yml - with: - framework: pytorch - version: "1.11" - sha: ${{ github.sha }} - secrets: inherit - run_past_ci_tensorflow_2-11: name: TensorFlow 2.11 needs: get_number diff --git a/README.md b/README.md index c748e675066202..42403f84b885da 100644 --- a/README.md +++ b/README.md @@ -249,7 +249,7 @@ The model itself is a regular [Pytorch `nn.Module`](https://pytorch.org/docs/sta ### With pip -This repository is tested on Python 3.9+, Flax 0.4.1+, PyTorch 1.11+, and TensorFlow 2.6+. +This repository is tested on Python 3.9+, Flax 0.4.1+, PyTorch 2.0+, and TensorFlow 2.6+. You should install 🤗 Transformers in a [virtual environment](https://docs.python.org/3/library/venv.html). If you're unfamiliar with Python virtual environments, check out the [user guide](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/). diff --git a/i18n/README_ar.md b/i18n/README_ar.md index 8160ec908d4411..c7249ac23d2e7f 100644 --- a/i18n/README_ar.md +++ b/i18n/README_ar.md @@ -245,7 +245,7 @@ limitations under the License. ### باستخدام pip -تم اختبار هذا المستودع على Python 3.9+، Flax 0.4.1+، PyTorch 1.11+، و TensorFlow 2.6+. +تم اختبار هذا المستودع على Python 3.9+، Flax 0.4.1+، PyTorch 2.0+، و TensorFlow 2.6+. يجب تثبيت 🤗 Transformers في [بيئة افتراضية](https://docs.python.org/3/library/venv.html). إذا كنت غير معتاد على البيئات الافتراضية Python، فراجع [دليل المستخدم](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/). diff --git a/i18n/README_de.md b/i18n/README_de.md index ccc9e6111a25f0..78447af41a7a82 100644 --- a/i18n/README_de.md +++ b/i18n/README_de.md @@ -246,7 +246,7 @@ Das Modell selbst ist ein reguläres [PyTorch `nn.Module`](https://pytorch.org/d ### Mit pip -Dieses Repository wurde mit Python 3.9+, Flax 0.4.1+, PyTorch 1.11+ und TensorFlow 2.6+ getestet. +Dieses Repository wurde mit Python 3.9+, Flax 0.4.1+, PyTorch 2.0+ und TensorFlow 2.6+ getestet. Sie sollten 🤗 Transformers in einer [virtuellen Umgebung](https://docs.python.org/3/library/venv.html) installieren. Wenn Sie mit virtuellen Python-Umgebungen nicht vertraut sind, schauen Sie sich den [Benutzerleitfaden](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/) an. diff --git a/i18n/README_es.md b/i18n/README_es.md index 5d5ba1b3249785..57eb8117fc0d5d 100644 --- a/i18n/README_es.md +++ b/i18n/README_es.md @@ -222,7 +222,7 @@ El modelo en si es un [Pytorch `nn.Module`](https://pytorch.org/docs/stable/nn.h ### Con pip -Este repositorio está probado en Python 3.9+, Flax 0.4.1+, PyTorch 1.11+ y TensorFlow 2.6+. +Este repositorio está probado en Python 3.9+, Flax 0.4.1+, PyTorch 2.0+ y TensorFlow 2.6+. Deberías instalar 🤗 Transformers en un [entorno virtual](https://docs.python.org/3/library/venv.html). 
Si no estas familiarizado con los entornos virtuales de Python, consulta la [guía de usuario](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/). diff --git a/i18n/README_fr.md b/i18n/README_fr.md index 97b11166b301a1..02714d52bff39b 100644 --- a/i18n/README_fr.md +++ b/i18n/README_fr.md @@ -243,7 +243,7 @@ Le modèle lui-même est un module [`nn.Module` PyTorch](https://pytorch.org/doc ### Avec pip -Ce référentiel est testé sur Python 3.9+, Flax 0.4.1+, PyTorch 1.11+ et TensorFlow 2.6+. +Ce référentiel est testé sur Python 3.9+, Flax 0.4.1+, PyTorch 2.0+ et TensorFlow 2.6+. Vous devriez installer 🤗 Transformers dans un [environnement virtuel](https://docs.python.org/3/library/venv.html). Si vous n'êtes pas familier avec les environnements virtuels Python, consultez le [guide utilisateur](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/). diff --git a/i18n/README_hd.md b/i18n/README_hd.md index 17efdd21eb04dc..1541e4df66fcbd 100644 --- a/i18n/README_hd.md +++ b/i18n/README_hd.md @@ -198,7 +198,7 @@ checkpoint: जाँच बिंदु ### पिप का उपयोग करना -इस रिपॉजिटरी का परीक्षण Python 3.9+, Flax 0.4.1+, PyTorch 1.11+ और TensorFlow 2.6+ के तहत किया गया है। +इस रिपॉजिटरी का परीक्षण Python 3.9+, Flax 0.4.1+, PyTorch 2.0+ और TensorFlow 2.6+ के तहत किया गया है। आप [वर्चुअल एनवायरनमेंट](https://docs.python.org/3/library/venv.html) में 🤗 ट्रांसफॉर्मर इंस्टॉल कर सकते हैं। यदि आप अभी तक पायथन के वर्चुअल एनवायरनमेंट से परिचित नहीं हैं, तो कृपया इसे [उपयोगकर्ता निर्देश](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/) पढ़ें। diff --git a/i18n/README_ja.md b/i18n/README_ja.md index 3d417098ea314d..fc3d4ae945cefd 100644 --- a/i18n/README_ja.md +++ b/i18n/README_ja.md @@ -256,7 +256,7 @@ Hugging Faceチームによって作られた **[トランスフォーマーを ### pipにて -このリポジトリは、Python 3.9+, Flax 0.4.1+, PyTorch 1.11+, TensorFlow 2.6+ でテストされています。 +このリポジトリは、Python 3.9+, Flax 0.4.1+, PyTorch 2.0+, TensorFlow 2.6+ でテストされています。 🤗Transformersは[仮想環境](https://docs.python.org/3/library/venv.html)にインストールする必要があります。Pythonの仮想環境に慣れていない場合は、[ユーザーガイド](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/)を確認してください。 diff --git a/i18n/README_ko.md b/i18n/README_ko.md index b9502db5dda845..6d6559398e4d17 100644 --- a/i18n/README_ko.md +++ b/i18n/README_ko.md @@ -242,7 +242,7 @@ Transformers에 달린 100,000개의 별을 축하하기 위해, 우리는 커 ### pip로 설치하기 -이 저장소는 Python 3.9+, Flax 0.4.1+, PyTorch 1.11+, TensorFlow 2.6+에서 테스트 되었습니다. +이 저장소는 Python 3.9+, Flax 0.4.1+, PyTorch 2.0+, TensorFlow 2.6+에서 테스트 되었습니다. [가상 환경](https://docs.python.org/3/library/venv.html)에 🤗 Transformers를 설치하세요. Python 가상 환경에 익숙하지 않다면, [사용자 가이드](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/)를 확인하세요. diff --git a/i18n/README_pt-br.md b/i18n/README_pt-br.md index d9248f9a151c36..f865f1b6ed9ca5 100644 --- a/i18n/README_pt-br.md +++ b/i18n/README_pt-br.md @@ -253,7 +253,7 @@ O modelo em si é um [Pytorch `nn.Module`](https://pytorch.org/docs/stable/nn.ht ### Com pip -Este repositório é testado no Python 3.9+, Flax 0.4.1+, PyTorch 1.11+ e TensorFlow 2.6+. +Este repositório é testado no Python 3.9+, Flax 0.4.1+, PyTorch 2.0+ e TensorFlow 2.6+. Você deve instalar o 🤗 Transformers em um [ambiente virtual](https://docs.python.org/3/library/venv.html). Se você não está familiarizado com ambientes virtuais em Python, confira o [guia do usuário](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/). 
diff --git a/i18n/README_ru.md b/i18n/README_ru.md index a359b52d2ccc73..c153474f339000 100644 --- a/i18n/README_ru.md +++ b/i18n/README_ru.md @@ -244,7 +244,7 @@ Hugging Face Hub. Мы хотим, чтобы Transformers позволил ра ### С помощью pip -Данный репозиторий протестирован на Python 3.9+, Flax 0.4.1+, PyTorch 1.11+ и TensorFlow 2.6+. +Данный репозиторий протестирован на Python 3.9+, Flax 0.4.1+, PyTorch 2.0+ и TensorFlow 2.6+. Устанавливать 🤗 Transformers следует в [виртуальной среде](https://docs.python.org/3/library/venv.html). Если вы не знакомы с виртуальными средами Python, ознакомьтесь с [руководством пользователя](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/). diff --git a/i18n/README_te.md b/i18n/README_te.md index a9795e9ca326aa..791ed6414f73d2 100644 --- a/i18n/README_te.md +++ b/i18n/README_te.md @@ -246,7 +246,7 @@ limitations under the License. ### పిప్ తో -ఈ రిపోజిటరీ పైథాన్ 3.9+, ఫ్లాక్స్ 0.4.1+, PyTorch 1.11+ మరియు TensorFlow 2.6+లో పరీక్షించబడింది. +ఈ రిపోజిటరీ పైథాన్ 3.9+, ఫ్లాక్స్ 0.4.1+, PyTorch 2.0+ మరియు TensorFlow 2.6+లో పరీక్షించబడింది. మీరు [వర్చువల్ వాతావరణం](https://docs.python.org/3/library/venv.html)లో 🤗 ట్రాన్స్‌ఫార్మర్‌లను ఇన్‌స్టాల్ చేయాలి. మీకు పైథాన్ వర్చువల్ పరిసరాల గురించి తెలియకుంటే, [యూజర్ గైడ్](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/) చూడండి. diff --git a/i18n/README_ur.md b/i18n/README_ur.md index cc37b5cfc4223d..2d4d7745f68eaf 100644 --- a/i18n/README_ur.md +++ b/i18n/README_ur.md @@ -259,7 +259,7 @@ limitations under the License. #### ‏ pip کے ساتھ -یہ ریپوزٹری Python 3.9+، Flax 0.4.1+، PyTorch 1.11+، اور TensorFlow 2.6+ پر ٹیسٹ کی گئی ہے۔ +یہ ریپوزٹری Python 3.9+، Flax 0.4.1+، PyTorch 2.0+، اور TensorFlow 2.6+ پر ٹیسٹ کی گئی ہے۔ آپ کو 🤗 Transformers کو ایک [ورچوئل ماحول](https://docs.python.org/3/library/venv.html) میں انسٹال کرنا چاہیے۔ اگر آپ Python ورچوئل ماحول سے واقف نہیں ہیں، تو [یوزر گائیڈ](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/) دیکھیں۔ diff --git a/i18n/README_vi.md b/i18n/README_vi.md index f523c282b680c4..4f7f67bfce90ff 100644 --- a/i18n/README_vi.md +++ b/i18n/README_vi.md @@ -245,7 +245,7 @@ Chính mô hình là một [Pytorch `nn.Module`](https://pytorch.org/docs/stable ### Sử dụng pip -Thư viện này được kiểm tra trên Python 3.9+, Flax 0.4.1+, PyTorch 1.11+ và TensorFlow 2.6+. +Thư viện này được kiểm tra trên Python 3.9+, Flax 0.4.1+, PyTorch 2.0+ và TensorFlow 2.6+. Bạn nên cài đặt 🤗 Transformers trong một [môi trường ảo Python](https://docs.python.org/3/library/venv.html). Nếu bạn chưa quen với môi trường ảo Python, hãy xem [hướng dẫn sử dụng](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/). 
diff --git a/i18n/README_zh-hans.md b/i18n/README_zh-hans.md index c9ac0357f18f1b..b4d121df0d3200 100644 --- a/i18n/README_zh-hans.md +++ b/i18n/README_zh-hans.md @@ -198,7 +198,7 @@ checkpoint: 检查点 ### 使用 pip -这个仓库已在 Python 3.9+、Flax 0.4.1+、PyTorch 1.11+ 和 TensorFlow 2.6+ 下经过测试。 +这个仓库已在 Python 3.9+、Flax 0.4.1+、PyTorch 2.0+ 和 TensorFlow 2.6+ 下经过测试。 你可以在[虚拟环境](https://docs.python.org/3/library/venv.html)中安装 🤗 Transformers。如果你还不熟悉 Python 的虚拟环境,请阅此[用户说明](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/)。 diff --git a/i18n/README_zh-hant.md b/i18n/README_zh-hant.md index 87c623ee84a61b..dcafd4958ed1d1 100644 --- a/i18n/README_zh-hant.md +++ b/i18n/README_zh-hant.md @@ -210,7 +210,7 @@ Tokenizer 為所有的預訓練模型提供了預處理,並可以直接轉換 ### 使用 pip -這個 Repository 已在 Python 3.9+、Flax 0.4.1+、PyTorch 1.11+ 和 TensorFlow 2.6+ 下經過測試。 +這個 Repository 已在 Python 3.9+、Flax 0.4.1+、PyTorch 2.0+ 和 TensorFlow 2.6+ 下經過測試。 你可以在[虛擬環境](https://docs.python.org/3/library/venv.html)中安裝 🤗 Transformers。如果你還不熟悉 Python 的虛擬環境,請閱此[使用者指引](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/)。 diff --git a/src/transformers/convert_pytorch_checkpoint_to_tf2.py b/src/transformers/convert_pytorch_checkpoint_to_tf2.py index 3875879f0e056d..c3431ad5b2e0ac 100755 --- a/src/transformers/convert_pytorch_checkpoint_to_tf2.py +++ b/src/transformers/convert_pytorch_checkpoint_to_tf2.py @@ -106,7 +106,6 @@ XLMWithLMHeadModel, XLNetLMHeadModel, ) - from .pytorch_utils import is_torch_greater_or_equal_than_1_13 logging.set_verbosity_info() @@ -279,7 +278,7 @@ def convert_pt_checkpoint_to_tf( if compare_with_pt_model: tfo = tf_model(tf_model.dummy_inputs, training=False) # build the network - weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {} + weights_only_kwarg = {"weights_only": True} state_dict = torch.load( pytorch_checkpoint_path, map_location="cpu", diff --git a/src/transformers/modeling_flax_pytorch_utils.py b/src/transformers/modeling_flax_pytorch_utils.py index 8bbd8587b683f4..8fbba8a1651364 100644 --- a/src/transformers/modeling_flax_pytorch_utils.py +++ b/src/transformers/modeling_flax_pytorch_utils.py @@ -63,8 +63,6 @@ def load_pytorch_checkpoint_in_flax_state_dict( else: try: import torch # noqa: F401 - - from .pytorch_utils import is_torch_greater_or_equal_than_1_13 # noqa: F401 except (ImportError, ModuleNotFoundError): logger.error( "Loading a PyTorch model in Flax, requires both PyTorch and Flax to be installed. 
Please see" @@ -73,7 +71,7 @@ def load_pytorch_checkpoint_in_flax_state_dict( ) raise - weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {} + weights_only_kwarg = {"weights_only": True} pt_state_dict = torch.load(pt_path, map_location="cpu", **weights_only_kwarg) logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters.") @@ -246,13 +244,11 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model): def convert_pytorch_sharded_state_dict_to_flax(shard_filenames, flax_model): import torch - from .pytorch_utils import is_torch_greater_or_equal_than_1_13 - # Load the index flax_state_dict = {} for shard_file in shard_filenames: # load using msgpack utils - weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {} + weights_only_kwarg = {"weights_only": True} pt_state_dict = torch.load(shard_file, **weights_only_kwarg) weight_dtypes = {k: v.dtype for k, v in pt_state_dict.items()} pt_state_dict = { diff --git a/src/transformers/modeling_tf_pytorch_utils.py b/src/transformers/modeling_tf_pytorch_utils.py index 7f1367481ade62..8ec24d6e1872ef 100644 --- a/src/transformers/modeling_tf_pytorch_utils.py +++ b/src/transformers/modeling_tf_pytorch_utils.py @@ -180,8 +180,6 @@ def load_pytorch_checkpoint_in_tf2_model( import tensorflow as tf # noqa: F401 import torch # noqa: F401 from safetensors.torch import load_file as safe_load_file # noqa: F401 - - from .pytorch_utils import is_torch_greater_or_equal_than_1_13 # noqa: F401 except ImportError: logger.error( "Loading a PyTorch model in TensorFlow, requires both PyTorch and TensorFlow to be installed. Please see " @@ -201,7 +199,7 @@ def load_pytorch_checkpoint_in_tf2_model( if pt_path.endswith(".safetensors"): state_dict = safe_load_file(pt_path) else: - weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {} + weights_only_kwarg = {"weights_only": True} state_dict = torch.load(pt_path, map_location="cpu", **weights_only_kwarg) pt_state_dict.update(state_dict) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 49d086c76e8683..a6d4a1cc5b54ed 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -54,7 +54,6 @@ apply_chunking_to_forward, find_pruneable_heads_and_indices, id_tensor_storage, - is_torch_greater_or_equal_than_1_13, prune_conv1d_layer, prune_layer, prune_linear_layer, @@ -476,7 +475,7 @@ def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True): error_message += f"\nMissing key(s): {str_unexpected_keys}." 
raise RuntimeError(error_message) - weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {} + weights_only_kwarg = {"weights_only": True} loader = safe_load_file if load_safe else partial(torch.load, map_location="cpu", **weights_only_kwarg) for shard_file in shard_files: @@ -532,7 +531,7 @@ def load_state_dict( and is_zipfile(checkpoint_file) ): extra_args = {"mmap": True} - weights_only_kwarg = {"weights_only": weights_only} if is_torch_greater_or_equal_than_1_13 else {} + weights_only_kwarg = {"weights_only": weights_only} return torch.load( checkpoint_file, map_location=map_location, diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index 8d5a224f4f6654..e0e4ff424cb47d 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -38,7 +38,6 @@ ) from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import is_torch_greater_or_equal_than_2_0 from ...utils import ( add_code_sample_docstrings, add_start_docstrings, @@ -815,14 +814,6 @@ def _init_weights(self, module: nn.Module): # Adapted from transformers.modeling_utils.PreTrainedModel._check_and_enable_sdpa @classmethod def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False) -> "PretrainedConfig": - # NOTE: Falcon supported SDPA from PyTorch 2.0. We keep it like that for backward compatibility (automatically use SDPA for torch>=2.0). - if hard_check_only: - if not is_torch_greater_or_equal_than_2_0: - raise ImportError("PyTorch SDPA requirements in Transformers are not met. Please install torch>=2.0.") - - if not is_torch_greater_or_equal_than_2_0: - return config - _is_bettertransformer = getattr(cls, "use_bettertransformer", False) if _is_bettertransformer: return config diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py index 6763695bfba036..ef23b5d208fd79 100755 --- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -36,7 +36,6 @@ TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import is_torch_greater_or_equal_than_1_13 from ...utils import ( add_code_sample_docstrings, add_start_docstrings, @@ -56,9 +55,6 @@ # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph. # It means that the function will not be traced through and simply appear as a node in the graph. if is_torch_fx_available(): - if not is_torch_greater_or_equal_than_1_13: - import torch.fx - _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask) diff --git a/src/transformers/models/phimoe/modeling_phimoe.py b/src/transformers/models/phimoe/modeling_phimoe.py index cd54b226e1d85c..8f6b092da6e6ad 100644 --- a/src/transformers/models/phimoe/modeling_phimoe.py +++ b/src/transformers/models/phimoe/modeling_phimoe.py @@ -33,7 +33,6 @@ ) from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import is_torch_greater_or_equal_than_1_13 from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, @@ -51,9 +50,6 @@ # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph. # It means that the function will not be traced through and simply appear as a node in the graph. 
if is_torch_fx_available(): - if not is_torch_greater_or_equal_than_1_13: - import torch.fx - _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask) diff --git a/src/transformers/models/superpoint/modeling_superpoint.py b/src/transformers/models/superpoint/modeling_superpoint.py index 1075de299a9f40..dcdd85460b39bd 100644 --- a/src/transformers/models/superpoint/modeling_superpoint.py +++ b/src/transformers/models/superpoint/modeling_superpoint.py @@ -25,7 +25,6 @@ ) from transformers.models.superpoint.configuration_superpoint import SuperPointConfig -from ...pytorch_utils import is_torch_greater_or_equal_than_1_13 from ...utils import ( ModelOutput, add_start_docstrings, @@ -314,7 +313,7 @@ def _sample_descriptors(keypoints, descriptors, scale: int = 8) -> torch.Tensor: divisor = divisor.to(keypoints) keypoints /= divisor keypoints = keypoints * 2 - 1 # normalize to (-1, 1) - kwargs = {"align_corners": True} if is_torch_greater_or_equal_than_1_13 else {} + kwargs = {"align_corners": True} # [batch_size, num_channels, num_keypoints, 2] -> [batch_size, num_channels, num_keypoints, 2] keypoints = keypoints.view(batch_size, 1, -1, 2) descriptors = nn.functional.grid_sample(descriptors, keypoints, mode="bilinear", **kwargs) diff --git a/src/transformers/models/tapas/modeling_tapas.py b/src/transformers/models/tapas/modeling_tapas.py index b74a27ae5ce589..2ea0d38a23f933 100644 --- a/src/transformers/models/tapas/modeling_tapas.py +++ b/src/transformers/models/tapas/modeling_tapas.py @@ -31,7 +31,6 @@ from ...pytorch_utils import ( apply_chunking_to_forward, find_pruneable_heads_and_indices, - is_torch_greater_or_equal_than_1_12, prune_linear_layer, ) from ...utils import ( @@ -46,12 +45,6 @@ logger = logging.get_logger(__name__) -if not is_torch_greater_or_equal_than_1_12: - logger.warning( - f"You are using torch=={torch.__version__}, but torch>=1.12.0 is required to use " - "TapasModel. Please upgrade torch." 
- ) - _CONFIG_FOR_DOC = "TapasConfig" _CHECKPOINT_FOR_DOC = "google/tapas-base" diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index e4df2e6ae3b718..5168904a3579d9 100755 --- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -38,7 +38,6 @@ XVectorOutput, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import is_torch_greater_or_equal_than_1_13 from ...utils import ( ModelOutput, add_code_sample_docstrings, @@ -1590,7 +1589,7 @@ def load_adapter(self, target_lang: str, force_load=True, **kwargs): cache_dir=cache_dir, ) - weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {} + weights_only_kwarg = {"weights_only": True} state_dict = torch.load( weight_path, map_location="cpu", diff --git a/src/transformers/pytorch_utils.py b/src/transformers/pytorch_utils.py index fab1b9118d18d3..95c8748375ce0a 100644 --- a/src/transformers/pytorch_utils.py +++ b/src/transformers/pytorch_utils.py @@ -34,9 +34,6 @@ is_torch_greater_or_equal_than_2_3 = parsed_torch_version_base >= version.parse("2.3") is_torch_greater_or_equal_than_2_2 = parsed_torch_version_base >= version.parse("2.2") is_torch_greater_or_equal_than_2_1 = parsed_torch_version_base >= version.parse("2.1") -is_torch_greater_or_equal_than_2_0 = parsed_torch_version_base >= version.parse("2.0") -is_torch_greater_or_equal_than_1_13 = parsed_torch_version_base >= version.parse("1.13") -is_torch_greater_or_equal_than_1_12 = parsed_torch_version_base >= version.parse("1.12") # Cache this result has it's a C FFI call which can be pretty time-consuming _torch_distributed_available = torch.distributed.is_available() diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 4d90c13df825f2..c878d2b345cc31 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -75,7 +75,6 @@ from .processing_utils import ProcessorMixin from .pytorch_utils import ( ALL_LAYERNORM_LAYERS, - is_torch_greater_or_equal_than_1_13, is_torch_greater_or_equal_than_2_3, ) from .tokenization_utils_base import PreTrainedTokenizerBase @@ -2778,7 +2777,7 @@ def _load_from_checkpoint(self, resume_from_checkpoint, model=None): ) if os.path.isfile(weights_file) or os.path.isfile(safe_weights_file) or is_fsdp_ckpt: - weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {} + weights_only_kwarg = {"weights_only": True} # If the model is on the GPU, it still works! if is_sagemaker_mp_enabled(): if os.path.isfile(os.path.join(resume_from_checkpoint, "user_content.pt")): @@ -2899,7 +2898,7 @@ def _load_best_model(self): or os.path.exists(best_safe_adapter_model_path) ): has_been_loaded = True - weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {} + weights_only_kwarg = {"weights_only": True} if is_sagemaker_mp_enabled(): if os.path.isfile(os.path.join(self.state.best_model_checkpoint, "user_content.pt")): # If the 'user_content.pt' file exists, load with the new smp api. 
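Throughout this patch the conditional `weights_only_kwarg` collapses into a hard-coded `{"weights_only": True}`, since the flag exists in every PyTorch release the library still supports. A minimal sketch (illustrative, not part of the diff) of what the flag does when a checkpoint is loaded:

```python
# Illustrative sketch -- not part of the patch. `weights_only=True` restricts
# torch.load's unpickler to tensors and primitive containers, so a checkpoint
# file cannot trigger arbitrary code execution on load.
import os
import tempfile

import torch

state_dict = {"weight": torch.randn(4, 4), "bias": torch.zeros(4)}

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "pytorch_model.bin")
    torch.save(state_dict, path)

    # The keyword is available from torch 1.13 onwards, so once PyTorch 1.x
    # support is dropped it can be passed unconditionally, as the diff does.
    loaded = torch.load(path, map_location="cpu", weights_only=True)

print({name: tuple(tensor.shape) for name, tensor in loaded.items()})
```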
diff --git a/src/transformers/trainer_pt_utils.py b/src/transformers/trainer_pt_utils.py index 5f78860fe6c115..da95329e184567 100644 --- a/src/transformers/trainer_pt_utils.py +++ b/src/transformers/trainer_pt_utils.py @@ -56,12 +56,7 @@ import torch_xla.core.xla_model as xm if is_torch_available(): - from .pytorch_utils import is_torch_greater_or_equal_than_2_0 - - if is_torch_greater_or_equal_than_2_0: - from torch.optim.lr_scheduler import LRScheduler - else: - from torch.optim.lr_scheduler import _LRScheduler as LRScheduler + from torch.optim.lr_scheduler import LRScheduler logger = logging.get_logger(__name__) diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 6b141cff39e1f7..6950e8e66d3ac1 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -71,8 +71,6 @@ import torch import torch.distributed as dist - from .pytorch_utils import is_torch_greater_or_equal_than_2_0 - if is_accelerate_available(): from accelerate.state import AcceleratorState, PartialState from accelerate.utils import DistributedType @@ -1157,7 +1155,7 @@ class TrainingArguments: }, ) dataloader_prefetch_factor: Optional[int] = field( - default=None if not is_torch_available() or is_torch_greater_or_equal_than_2_0 else 2, + default=None, metadata={ "help": ( "Number of batches loaded in advance by each worker. " @@ -1702,14 +1700,6 @@ def __post_init__(self): raise ValueError( "Your setup doesn't support bf16/gpu. You need torch>=1.10, using Ampere GPU with cuda>=11.0" ) - elif not is_torch_xpu_available(): - # xpu - from .pytorch_utils import is_torch_greater_or_equal_than_1_12 - - if not is_torch_greater_or_equal_than_1_12: - raise ValueError( - "Your setup doesn't support bf16/xpu. You need torch>=1.12, using Intel XPU/GPU with IPEX installed" - ) if self.fp16 and self.bf16: raise ValueError("At most one of fp16 and bf16 can be True, but not both") @@ -2056,11 +2046,7 @@ def __post_init__(self): if self.use_cpu: self.dataloader_pin_memory = False - if ( - (not is_torch_available() or is_torch_greater_or_equal_than_2_0) - and self.dataloader_num_workers == 0 - and self.dataloader_prefetch_factor is not None - ): + if self.dataloader_num_workers == 0 and self.dataloader_prefetch_factor is not None: raise ValueError( "--dataloader_prefetch_factor can only be set when data is loaded in a different process, i.e." " when --dataloader_num_workers > 1." 
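The `training_args.py` hunk above removes the PyTorch-version guard around `dataloader_prefetch_factor` but keeps the rule that prefetching requires worker processes. A short illustrative sketch (separate from the diff) of the underlying `DataLoader` behaviour that this check mirrors:

```python
# Illustrative sketch -- not part of the patch. `prefetch_factor` is a per-worker
# setting, so it is only meaningful when the DataLoader spawns worker processes.
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(32, dtype=torch.float32).unsqueeze(1))

# Each of the 2 workers keeps up to 4 batches prepared ahead of consumption.
loader = DataLoader(dataset, batch_size=8, num_workers=2, prefetch_factor=4)
print("per-worker prefetch:", loader.prefetch_factor)

# With num_workers=0 everything is loaded in the main process, so passing
# prefetch_factor raises a ValueError -- the same constraint TrainingArguments enforces.
try:
    DataLoader(dataset, batch_size=8, num_workers=0, prefetch_factor=4)
except ValueError as err:
    print(err)
```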
diff --git a/src/transformers/utils/fx.py b/src/transformers/utils/fx.py index 101b34182a7309..45fa3d9ca68c51 100755 --- a/src/transformers/utils/fx.py +++ b/src/transformers/utils/fx.py @@ -60,7 +60,6 @@ MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES, MODEL_MAPPING_NAMES, ) -from ..pytorch_utils import is_torch_greater_or_equal_than_2_0 from .import_utils import ( ENV_VARS_TRUE_VALUES, TORCH_FX_REQUIRED_VERSION, @@ -635,10 +634,9 @@ def to_concrete(t): operator.getitem: operator_getitem, } -if is_torch_greater_or_equal_than_2_0: - _MANUAL_META_OVERRIDES[torch.nn.functional.scaled_dot_product_attention] = ( - torch_nn_functional_scaled_dot_product_attention - ) +_MANUAL_META_OVERRIDES[torch.nn.functional.scaled_dot_product_attention] = ( + torch_nn_functional_scaled_dot_product_attention +) class HFProxy(Proxy): diff --git a/tests/models/aria/test_modeling_aria.py b/tests/models/aria/test_modeling_aria.py index d3458530ac349e..b6f1da56c6782e 100644 --- a/tests/models/aria/test_modeling_aria.py +++ b/tests/models/aria/test_modeling_aria.py @@ -45,8 +45,7 @@ if is_torch_available(): import torch -else: - is_torch_greater_or_equal_than_2_0 = False + if is_vision_available(): from PIL import Image diff --git a/tests/models/falcon_mamba/test_modeling_falcon_mamba.py b/tests/models/falcon_mamba/test_modeling_falcon_mamba.py index 893132f4337dd4..f02e8f167636eb 100644 --- a/tests/models/falcon_mamba/test_modeling_falcon_mamba.py +++ b/tests/models/falcon_mamba/test_modeling_falcon_mamba.py @@ -43,9 +43,6 @@ FalconMambaModel, ) from transformers.cache_utils import MambaCache - from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_0 -else: - is_torch_greater_or_equal_than_2_0 = False # Copied from transformers.tests.models.mamba.MambaModelTester with Mamba->FalconMamba,mamba->falcon_mamba @@ -246,9 +243,6 @@ def prepare_config_and_inputs_for_common(self): return config, inputs_dict -@unittest.skipIf( - not is_torch_greater_or_equal_than_2_0, reason="See https://github.com/huggingface/transformers/pull/24204" -) @require_torch # Copied from transformers.tests.models.mamba.MambaModelTest with Mamba->Falcon,mamba->falcon_mamba,FalconMambaCache->MambaCache class FalconMambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): diff --git a/tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py b/tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py index 1db484c4062c35..281594492500b0 100644 --- a/tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py +++ b/tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py @@ -37,9 +37,6 @@ GPTBigCodeModel, ) from transformers.models.gpt_bigcode.modeling_gpt_bigcode import GPTBigCodeAttention - from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_12 -else: - is_torch_greater_or_equal_than_1_12 = False class GPTBigCodeModelTester: @@ -504,10 +501,6 @@ class GPTBigCodeMHAModelTest(GPTBigCodeModelTest): multi_query = False -@unittest.skipIf( - not is_torch_greater_or_equal_than_1_12, - reason="`GPTBigCode` checkpoints use `PytorchGELUTanh` which requires `torch>=1.12.0`.", -) @slow @require_torch class GPTBigCodeModelLanguageGenerationTest(unittest.TestCase): diff --git a/tests/models/gptj/test_modeling_gptj.py b/tests/models/gptj/test_modeling_gptj.py index afc741cd502dec..50840bbcfaa6dc 100644 --- a/tests/models/gptj/test_modeling_gptj.py +++ b/tests/models/gptj/test_modeling_gptj.py @@ -41,9 +41,6 @@ GPTJForSequenceClassification, GPTJModel, ) - from transformers.pytorch_utils import 
is_torch_greater_or_equal_than_1_12 -else: - is_torch_greater_or_equal_than_1_12 = False class GPTJModelTester: @@ -363,15 +360,9 @@ class GPTJModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin test_model_parallel = False test_head_masking = False - @unittest.skipIf( - not is_torch_greater_or_equal_than_1_12, reason="PR #22069 made changes that require torch v1.12+." - ) def test_torch_fx(self): super().test_torch_fx() - @unittest.skipIf( - not is_torch_greater_or_equal_than_1_12, reason="PR #22069 made changes that require torch v1.12+." - ) def test_torch_fx_output_loss(self): super().test_torch_fx_output_loss() diff --git a/tests/models/idefics/test_modeling_idefics.py b/tests/models/idefics/test_modeling_idefics.py index 12004cc3c8ad89..94229b13d2cbfe 100644 --- a/tests/models/idefics/test_modeling_idefics.py +++ b/tests/models/idefics/test_modeling_idefics.py @@ -44,9 +44,6 @@ from transformers import IdeficsForVisionText2Text, IdeficsModel, IdeficsProcessor from transformers.models.idefics.configuration_idefics import IdeficsPerceiverConfig, IdeficsVisionConfig - from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_0 -else: - is_torch_greater_or_equal_than_2_0 = False if is_vision_available(): from PIL import Image @@ -327,7 +324,6 @@ def test_eager_matches_sdpa_generate(self): self.skipTest(reason="Idefics has a hard requirement on SDPA, skipping this test") -@unittest.skipIf(not is_torch_greater_or_equal_than_2_0, reason="pytorch 2.0 or higher is required") @require_torch class IdeficsModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): all_model_classes = (IdeficsModel, IdeficsForVisionText2Text) if is_torch_available() else () @@ -594,7 +590,6 @@ def test_sdpa_can_dispatch_non_composite_models(self): pass -@unittest.skipIf(not is_torch_greater_or_equal_than_2_0, reason="pytorch 2.0 or higher is required") @require_torch class IdeficsForVisionText2TextTest(IdeficsModelTest, GenerationTesterMixin, unittest.TestCase): all_model_classes = (IdeficsForVisionText2Text,) if is_torch_available() else () @@ -818,7 +813,6 @@ def test_sdpa_can_dispatch_non_composite_models(self): pass -@unittest.skipIf(not is_torch_greater_or_equal_than_2_0, reason="pytorch 2.0 or higher is required") @require_torch @require_vision class IdeficsModelIntegrationTest(TestCasePlus): diff --git a/tests/models/idefics2/test_modeling_idefics2.py b/tests/models/idefics2/test_modeling_idefics2.py index 83e125c07c15bc..974628c8b4324f 100644 --- a/tests/models/idefics2/test_modeling_idefics2.py +++ b/tests/models/idefics2/test_modeling_idefics2.py @@ -48,8 +48,6 @@ if is_torch_available(): import torch -else: - is_torch_greater_or_equal_than_2_0 = False if is_vision_available(): from PIL import Image diff --git a/tests/models/idefics3/test_modeling_idefics3.py b/tests/models/idefics3/test_modeling_idefics3.py index 5bfd4c3f3c0e83..c25fa1180649fa 100644 --- a/tests/models/idefics3/test_modeling_idefics3.py +++ b/tests/models/idefics3/test_modeling_idefics3.py @@ -40,8 +40,6 @@ Idefics3ForConditionalGeneration, Idefics3Model, ) -else: - is_torch_greater_or_equal_than_2_0 = False if is_vision_available(): from PIL import Image diff --git a/tests/models/llava/test_modeling_llava.py b/tests/models/llava/test_modeling_llava.py index 3d08ab35e0f630..b4a959a00d2a0c 100644 --- a/tests/models/llava/test_modeling_llava.py +++ b/tests/models/llava/test_modeling_llava.py @@ -43,8 +43,7 @@ if is_torch_available(): import torch -else: - is_torch_greater_or_equal_than_2_0 = False 
+ if is_vision_available(): from PIL import Image diff --git a/tests/models/llava_next/test_modeling_llava_next.py b/tests/models/llava_next/test_modeling_llava_next.py index c258ce96b94e48..14b0fb8cc07db7 100644 --- a/tests/models/llava_next/test_modeling_llava_next.py +++ b/tests/models/llava_next/test_modeling_llava_next.py @@ -48,8 +48,7 @@ import torch from transformers.models.llava_next.modeling_llava_next import image_size_to_num_patches -else: - is_torch_greater_or_equal_than_2_0 = False + if is_vision_available(): from PIL import Image diff --git a/tests/models/llava_next_video/test_modeling_llava_next_video.py b/tests/models/llava_next_video/test_modeling_llava_next_video.py index a6fb341ff9bf56..c431f91bf5102f 100644 --- a/tests/models/llava_next_video/test_modeling_llava_next_video.py +++ b/tests/models/llava_next_video/test_modeling_llava_next_video.py @@ -48,8 +48,6 @@ if is_torch_available(): import torch -else: - is_torch_greater_or_equal_than_2_0 = False if is_vision_available(): from PIL import Image diff --git a/tests/models/llava_onevision/test_modeling_llava_onevision.py b/tests/models/llava_onevision/test_modeling_llava_onevision.py index a217eee2c70671..6965d2033ec730 100644 --- a/tests/models/llava_onevision/test_modeling_llava_onevision.py +++ b/tests/models/llava_onevision/test_modeling_llava_onevision.py @@ -48,8 +48,6 @@ if is_torch_available(): import torch -else: - is_torch_greater_or_equal_than_2_0 = False if is_vision_available(): from PIL import Image diff --git a/tests/models/mamba/test_modeling_mamba.py b/tests/models/mamba/test_modeling_mamba.py index d432dfa93df487..455022140f7c5b 100644 --- a/tests/models/mamba/test_modeling_mamba.py +++ b/tests/models/mamba/test_modeling_mamba.py @@ -38,9 +38,6 @@ MambaModel, ) from transformers.models.mamba.modeling_mamba import MambaCache - from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_0 -else: - is_torch_greater_or_equal_than_2_0 = False class MambaModelTester: @@ -239,9 +236,6 @@ def prepare_config_and_inputs_for_common(self): return config, inputs_dict -@unittest.skipIf( - not is_torch_greater_or_equal_than_2_0, reason="See https://github.com/huggingface/transformers/pull/24204" -) @require_torch class MambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): all_model_classes = (MambaModel, MambaForCausalLM) if is_torch_available() else () diff --git a/tests/models/mamba2/test_modeling_mamba2.py b/tests/models/mamba2/test_modeling_mamba2.py index c2ef68f2614ea5..17cbdc1e8d51dd 100644 --- a/tests/models/mamba2/test_modeling_mamba2.py +++ b/tests/models/mamba2/test_modeling_mamba2.py @@ -37,9 +37,6 @@ Mamba2Model, ) from transformers.models.mamba2.modeling_mamba2 import Mamba2Cache, Mamba2Mixer - from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_0 -else: - is_torch_greater_or_equal_than_2_0 = False class Mamba2ModelTester: @@ -214,9 +211,6 @@ def create_and_check_mamba2_slow_vs_fast_forward(self, config, input_ids, *args, self.parent.assertTrue(torch.allclose(outputs_fast, outputs_slow, atol=1e-3, rtol=1e-3)) -@unittest.skipIf( - not is_torch_greater_or_equal_than_2_0, reason="See https://github.com/huggingface/transformers/pull/24204" -) @require_torch class Mamba2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): all_model_classes = (Mamba2Model, Mamba2ForCausalLM) if is_torch_available() else () diff --git a/tests/models/paligemma/test_modeling_paligemma.py 
b/tests/models/paligemma/test_modeling_paligemma.py index 5ffea7ffe55087..f973e1211dc081 100644 --- a/tests/models/paligemma/test_modeling_paligemma.py +++ b/tests/models/paligemma/test_modeling_paligemma.py @@ -40,8 +40,7 @@ if is_torch_available(): import torch -else: - is_torch_greater_or_equal_than_2_0 = False + if is_vision_available(): from PIL import Image diff --git a/tests/models/pixtral/test_modeling_pixtral.py b/tests/models/pixtral/test_modeling_pixtral.py index 0c36cb5a4e0554..3e5667caf45e3e 100644 --- a/tests/models/pixtral/test_modeling_pixtral.py +++ b/tests/models/pixtral/test_modeling_pixtral.py @@ -33,8 +33,7 @@ if is_torch_available(): import torch -else: - is_torch_greater_or_equal_than_2_0 = False + if is_vision_available(): pass diff --git a/tests/models/qwen2_audio/test_modeling_qwen2_audio.py b/tests/models/qwen2_audio/test_modeling_qwen2_audio.py index 4806ec2c72d339..8974d6923b391c 100644 --- a/tests/models/qwen2_audio/test_modeling_qwen2_audio.py +++ b/tests/models/qwen2_audio/test_modeling_qwen2_audio.py @@ -41,8 +41,6 @@ if is_torch_available(): import torch -else: - is_torch_greater_or_equal_than_2_0 = False class Qwen2AudioModelTester: diff --git a/tests/models/qwen2_vl/test_modeling_qwen2_vl.py b/tests/models/qwen2_vl/test_modeling_qwen2_vl.py index 93ed33ae774458..2c27e1a03a647c 100644 --- a/tests/models/qwen2_vl/test_modeling_qwen2_vl.py +++ b/tests/models/qwen2_vl/test_modeling_qwen2_vl.py @@ -47,8 +47,6 @@ if is_torch_available(): import torch -else: - is_torch_greater_or_equal_than_2_0 = False if is_vision_available(): from PIL import Image diff --git a/tests/models/rwkv/test_modeling_rwkv.py b/tests/models/rwkv/test_modeling_rwkv.py index 5e82956e3efa6c..0bc5c2de070135 100644 --- a/tests/models/rwkv/test_modeling_rwkv.py +++ b/tests/models/rwkv/test_modeling_rwkv.py @@ -33,9 +33,6 @@ RwkvForCausalLM, RwkvModel, ) - from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_0 -else: - is_torch_greater_or_equal_than_2_0 = False class RwkvModelTester: @@ -231,9 +228,6 @@ def prepare_config_and_inputs_for_common(self): return config, inputs_dict -@unittest.skipIf( - not is_torch_greater_or_equal_than_2_0, reason="See https://github.com/huggingface/transformers/pull/24204" -) @require_torch class RwkvModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): all_model_classes = (RwkvModel, RwkvForCausalLM) if is_torch_available() else () @@ -440,9 +434,6 @@ def test_left_padding_compatibility(self): pass -@unittest.skipIf( - not is_torch_greater_or_equal_than_2_0, reason="See https://github.com/huggingface/transformers/pull/24204" -) @slow class RWKVIntegrationTests(unittest.TestCase): def setUp(self): diff --git a/tests/models/tapas/test_modeling_tapas.py b/tests/models/tapas/test_modeling_tapas.py index 4ee159d6bddd1d..05618f4a4efd8c 100644 --- a/tests/models/tapas/test_modeling_tapas.py +++ b/tests/models/tapas/test_modeling_tapas.py @@ -60,9 +60,6 @@ reduce_mean, reduce_sum, ) - from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_12 -else: - is_torch_greater_or_equal_than_1_12 = False class TapasModelTester: @@ -411,7 +408,6 @@ def prepare_config_and_inputs_for_common(self): return config, inputs_dict -@unittest.skipIf(not is_torch_greater_or_equal_than_1_12, reason="Tapas is only available in torch v1.12+") @require_torch class TapasModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): all_model_classes = ( @@ -578,7 +574,6 @@ def 
prepare_tapas_batch_inputs_for_training(): return table, queries, answer_coordinates, answer_text, float_answer -@unittest.skipIf(not is_torch_greater_or_equal_than_1_12, reason="Tapas is only available in torch v1.12+") @require_torch class TapasModelIntegrationTest(unittest.TestCase): @cached_property @@ -930,10 +925,6 @@ def test_inference_classification_head(self): self.assertTrue(torch.allclose(outputs.logits, expected_tensor, atol=0.05)) -# Below: tests for Tapas utilities which are defined in modeling_tapas.py. -# These are based on segmented_tensor_test.py of the original implementation. -# URL: https://github.com/google-research/tapas/blob/master/tapas/models/segmented_tensor_test.py -@unittest.skipIf(not is_torch_greater_or_equal_than_1_12, reason="Tapas is only available in torch v1.12+") @require_torch class TapasUtilitiesTest(unittest.TestCase): def _prepare_tables(self): diff --git a/tests/models/tapas/test_tokenization_tapas.py b/tests/models/tapas/test_tokenization_tapas.py index 0a911f7182b4a0..9a3a2578fd16b3 100644 --- a/tests/models/tapas/test_tokenization_tapas.py +++ b/tests/models/tapas/test_tokenization_tapas.py @@ -23,7 +23,7 @@ import pandas as pd from parameterized import parameterized -from transformers import AddedToken, is_torch_available +from transformers import AddedToken from transformers.models.tapas.tokenization_tapas import ( VOCAB_FILES_NAMES, BasicTokenizer, @@ -45,12 +45,6 @@ from ...test_tokenization_common import TokenizerTesterMixin, filter_non_english, merge_model_tokenizer_mappings -if is_torch_available(): - from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_12 -else: - is_torch_greater_or_equal_than_1_12 = False - - @require_tokenizers @require_pandas class TapasTokenizationTest(TokenizerTesterMixin, unittest.TestCase): @@ -1048,7 +1042,6 @@ def test_token_type_ids(self): # Do the same test as modeling common. 
self.assertIn(0, output["token_type_ids"][0]) - @unittest.skipIf(not is_torch_greater_or_equal_than_1_12, reason="Tapas is only available in torch v1.12+") @require_torch @slow def test_torch_encode_plus_sent_to_model(self): diff --git a/tests/models/vipllava/test_modeling_vipllava.py b/tests/models/vipllava/test_modeling_vipllava.py index 4f501fc10a028f..8286b3c94fb9da 100644 --- a/tests/models/vipllava/test_modeling_vipllava.py +++ b/tests/models/vipllava/test_modeling_vipllava.py @@ -41,8 +41,6 @@ if is_torch_available(): import torch -else: - is_torch_greater_or_equal_than_2_0 = False if is_vision_available(): from PIL import Image diff --git a/tests/pipelines/test_pipelines_table_question_answering.py b/tests/pipelines/test_pipelines_table_question_answering.py index 9481ab200063f8..e2141dc7cc2f66 100644 --- a/tests/pipelines/test_pipelines_table_question_answering.py +++ b/tests/pipelines/test_pipelines_table_question_answering.py @@ -20,7 +20,6 @@ AutoTokenizer, TableQuestionAnsweringPipeline, TFAutoModelForTableQuestionAnswering, - is_torch_available, pipeline, ) from transformers.testing_utils import ( @@ -33,12 +32,6 @@ ) -if is_torch_available(): - from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_12 -else: - is_torch_greater_or_equal_than_1_12 = False - - @is_pipeline_test class TQAPipelineTests(unittest.TestCase): # Putting it there for consistency, but TQA do not have fast tokenizer @@ -150,7 +143,6 @@ def test_small_model_tf(self): }, ) - @unittest.skipIf(not is_torch_greater_or_equal_than_1_12, reason="Tapas is only available in torch v1.12+") @require_torch def test_small_model_pt(self, torch_dtype="float32"): model_id = "lysandre/tiny-tapas-random-wtq" @@ -253,12 +245,10 @@ def test_small_model_pt(self, torch_dtype="float32"): }, ) - @unittest.skipIf(not is_torch_greater_or_equal_than_1_12, reason="Tapas is only available in torch v1.12+") @require_torch def test_small_model_pt_fp16(self): self.test_small_model_pt(torch_dtype="float16") - @unittest.skipIf(not is_torch_greater_or_equal_than_1_12, reason="Tapas is only available in torch v1.12+") @require_torch def test_slow_tokenizer_sqa_pt(self, torch_dtype="float32"): model_id = "lysandre/tiny-tapas-random-sqa" @@ -378,7 +368,6 @@ def test_slow_tokenizer_sqa_pt(self, torch_dtype="float32"): }, ) - @unittest.skipIf(not is_torch_greater_or_equal_than_1_12, reason="Tapas is only available in torch v1.12+") @require_torch def test_slow_tokenizer_sqa_pt_fp16(self): self.test_slow_tokenizer_sqa_pt(torch_dtype="float16") @@ -505,7 +494,6 @@ def test_slow_tokenizer_sqa_tf(self): }, ) - @unittest.skipIf(not is_torch_greater_or_equal_than_1_12, reason="Tapas is only available in torch v1.12+") @slow @require_torch def test_integration_wtq_pt(self, torch_dtype="float32"): @@ -551,7 +539,6 @@ def test_integration_wtq_pt(self, torch_dtype="float32"): ] self.assertListEqual(results, expected_results) - @unittest.skipIf(not is_torch_greater_or_equal_than_1_12, reason="Tapas is only available in torch v1.12+") @slow @require_torch def test_integration_wtq_pt_fp16(self): @@ -606,7 +593,6 @@ def test_integration_wtq_tf(self): ] self.assertListEqual(results, expected_results) - @unittest.skipIf(not is_torch_greater_or_equal_than_1_12, reason="Tapas is only available in torch v1.12+") @slow @require_torch def test_integration_sqa_pt(self, torch_dtype="float32"): @@ -632,7 +618,6 @@ def test_integration_sqa_pt(self, torch_dtype="float32"): ] self.assertListEqual(results, expected_results) - @unittest.skipIf(not 
is_torch_greater_or_equal_than_1_12, reason="Tapas is only available in torch v1.12+") @slow @require_torch def test_integration_sqa_pt_fp16(self): From 40292aa4e95abb827847f77318a32efc1e76d973 Mon Sep 17 00:00:00 2001 From: Jiwoong Date: Fri, 20 Dec 2024 22:37:04 +0900 Subject: [PATCH 21/23] bugfix: torch.export failure caused by `_make_causal_mask` (#35291) * bugfix: torch.export failure caused by `_make_causal_mask` Recent changes in torch dynamo prevent mutations on tensors converted with aten::_to_copy. To address this, we can clone such a tensor before performing the in-place operation `masked_fill_`, but only when the code is being compiled by torch dynamo. (relevant issue: https://github.com/pytorch/pytorch/issues/127571) * chore: use `is_torchdynamo_compiling` instead of `torch._dynamo.is_compiling` --- src/transformers/modeling_attn_mask_utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/transformers/modeling_attn_mask_utils.py b/src/transformers/modeling_attn_mask_utils.py index 4319c021cb2bc3..09fc77e46b07ed 100755 --- a/src/transformers/modeling_attn_mask_utils.py +++ b/src/transformers/modeling_attn_mask_utils.py @@ -169,6 +169,10 @@ def _make_causal_mask( diagonal = past_key_values_length - sliding_window - 1 context_mask = torch.tril(torch.ones_like(mask, dtype=torch.bool), diagonal=diagonal) + # Recent changes in PyTorch prevent mutations on tensors converted with aten::_to_copy + # See https://github.com/pytorch/pytorch/issues/127571 + if is_torchdynamo_compiling(): + mask = mask.clone() mask.masked_fill_(context_mask, torch.finfo(dtype).min) return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) From 34ad1bd287d26565a23119bdd22632eb0b573c73 Mon Sep 17 00:00:00 2001 From: nhamanasu Date: Fri, 20 Dec 2024 23:04:36 +0900 Subject: [PATCH 22/23] update codecarbon (#35243) * update codecarbon * replace directly-specified-test-dirs with tmp_dir * Revert "replace directly-specified-test-dirs with tmp_dir" This reverts commit 310a6d962ec83db3f6d4f96daeeba5c6746f736c. * revert the change of .gitignore * Update .gitignore --------- Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com> --- setup.py | 2 +- src/transformers/dependency_versions_table.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 9e678db9978b14..a78bb20dd0a4b0 100644 --- a/setup.py +++ b/setup.py @@ -100,7 +100,7 @@ "av==9.2.0", # Latest version of PyAV (10.0.0) has issues with audio stream.
"beautifulsoup4", "blobfile", - "codecarbon==1.2.0", + "codecarbon>=2.8.1", "cookiecutter==1.7.3", "dataclasses", "datasets!=2.5.0", diff --git a/src/transformers/dependency_versions_table.py b/src/transformers/dependency_versions_table.py index c370c7a5d7c18c..6a737b805a456c 100644 --- a/src/transformers/dependency_versions_table.py +++ b/src/transformers/dependency_versions_table.py @@ -7,7 +7,7 @@ "av": "av==9.2.0", "beautifulsoup4": "beautifulsoup4", "blobfile": "blobfile", - "codecarbon": "codecarbon==1.2.0", + "codecarbon": "codecarbon>=2.8.1", "cookiecutter": "cookiecutter==1.7.3", "dataclasses": "dataclasses", "datasets": "datasets!=2.5.0", From 6fae2a84aebe99de035db9faddc88c08696f0705 Mon Sep 17 00:00:00 2001 From: Arthur <48595927+ArthurZucker@users.noreply.github.com> Date: Fri, 20 Dec 2024 15:10:43 +0100 Subject: [PATCH 23/23] Update test fetcher when we want to test all (#35364) * [test-all] * style * [test-all] * [test_all] * [test_all] * style --- utils/tests_fetcher.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/utils/tests_fetcher.py b/utils/tests_fetcher.py index 906e85e1de61a5..c641ccb21e2984 100644 --- a/utils/tests_fetcher.py +++ b/utils/tests_fetcher.py @@ -995,9 +995,7 @@ def _print_list(l) -> str: def infer_tests_to_run( - output_file: str, - diff_with_last_commit: bool = False, - filter_models: bool = False, + output_file: str, diff_with_last_commit: bool = False, filter_models: bool = False, test_all: bool = False ): """ The main function called by the test fetcher. Determines the tests to run from the diff. @@ -1018,7 +1016,11 @@ def infer_tests_to_run( Whether or not to filter the tests to core models only, when a file modified results in a lot of model tests. """ - modified_files = get_modified_python_files(diff_with_last_commit=diff_with_last_commit) + if not test_all: + modified_files = get_modified_python_files(diff_with_last_commit=diff_with_last_commit) + else: + modified_files = [str(k) for k in PATH_TO_TESTS.glob("*/*") if str(k).endswith(".py") and "test_" in str(k)] + print("\n### test_all is TRUE, FETCHING ALL FILES###\n") print(f"\n### MODIFIED FILES ###\n{_print_list(modified_files)}") # Create the map that will give us all impacted modules. @@ -1230,5 +1232,6 @@ def create_test_list_from_filter(full_test_list, out_path): args.output_file, diff_with_last_commit=diff_with_last_commit, filter_models=False, + test_all=commit_flags["test_all"], ) filter_tests(args.output_file, ["repo_utils"])