diff --git a/src/transformers/models/diffllama/modeling_diffllama.py b/src/transformers/models/diffllama/modeling_diffllama.py index d719b8224e3bb4..626cce2c821855 100644 --- a/src/transformers/models/diffllama/modeling_diffllama.py +++ b/src/transformers/models/diffllama/modeling_diffllama.py @@ -31,7 +31,7 @@ from ...cache_utils import Cache, DynamicCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter -from ...modeling_flash_attention_utils import _flash_attention_forward +from ...modeling_flash_attention_utils import FlashAttentionKwargs, _flash_attention_forward from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -39,7 +39,9 @@ SequenceClassifierOutputWithPast, TokenClassifierOutput, ) +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import PreTrainedModel +from ...processing_utils import Unpack from ...utils import ( add_code_sample_docstrings, add_start_docstrings, @@ -57,6 +59,93 @@ _CONFIG_FOR_DOC = "DiffLlamaConfig" +class DiffLlamaRotaryEmbedding(nn.Module): + def __init__( + self, + dim=None, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + rope_type="default", + config: Optional[DiffLlamaConfig] = None, + ): + super().__init__() + # TODO (joao): remove the `if` below, only used for BC + self.rope_kwargs = {} + if config is None: + logger.warning_once( + "`DiffLlamaRotaryEmbedding` can now be fully parameterized by passing the model config through the " + "`config` argument. All other arguments will be removed in v4.46" + ) + self.rope_kwargs = { + "rope_type": rope_type, + "factor": scaling_factor, + "dim": dim, + "base": base, + "max_position_embeddings": max_position_embeddings, + } + self.rope_type = rope_type + self.max_seq_len_cached = max_position_embeddings + self.original_max_seq_len = max_position_embeddings + else: + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + class DiffLlamaMLP(nn.Module): def __init__(self, config): super().__init__() @@ -481,23 +570,23 @@ def forward( class DiffLlamaRMSNorm(nn.Module): - def __init__(self, dim: int, eps: float = 1e-6): + def __init__(self, hidden_size, eps=1e-6): + """ + DiffLlamaRMSNorm is equivalent to T5LayerNorm + """ super().__init__() - self.eps = eps - self.weight = nn.Parameter(torch.zeros(dim)) + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps - def _norm(self, x): - return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) - - def forward(self, x): - output = self._norm(x.float()) - # Llama does x.to(float16) * w whilst DiffLlama is (x * w).to(float16) - # See https://github.com/huggingface/transformers/pull/29402 - output = output * (1.0 + self.weight.float()) - return output.type_as(x) + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) def extra_repr(self): - return f"{tuple(self.weight.shape)}, eps={self.eps}" + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" DIFFLLAMA_ATTENTION_CLASSES = { @@ -511,7 +600,9 @@ class DiffLlamaDecoderLayer(nn.Module): def __init__(self, config: DiffLlamaConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size + self.self_attn = DIFFLLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + self.mlp = DiffLlamaMLP(config) self.input_layernorm = DiffLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = DiffLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -525,6 +616,7 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -542,6 +634,9 @@ def forward( past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): Indices depicting the position of the input sequence tokens in the sequence + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. kwargs (`dict`, *optional*): Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code into the model @@ -559,6 +654,7 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + position_embeddings=position_embeddings, **kwargs, ) hidden_states = residual + hidden_states @@ -722,6 +818,7 @@ def __init__(self, config: DiffLlamaConfig): [DiffLlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) self.norm = DiffLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = DiffLlamaRotaryEmbedding(config=config) self.gradient_checkpointing = False if getattr(config, "pretraining_tp", 1) != 1: @@ -749,6 +846,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -770,9 +868,9 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) # kept for BC (non `Cache` `past_key_values` inputs) - return_legacy_cache = False # noqa: F841 + return_legacy_cache = False if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True # noqa: F841 + return_legacy_cache = True if past_key_values is None: past_key_values = DynamicCache() else: @@ -788,22 +886,16 @@ def forward( cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) - if position_ids is None: position_ids = cache_position.unsqueeze(0) causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions ) - - # embed positions hidden_states = inputs_embeds - # normalized - # DiffLlama downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5 - # See https://github.com/huggingface/transformers/pull/29402 - normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype) - hidden_states = hidden_states * normalizer + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) # decoder layers all_hidden_states = () if output_hidden_states else None @@ -824,6 +916,7 @@ def forward( output_attentions, use_cache, cache_position, + position_embeddings, ) else: layer_outputs = decoder_layer( @@ -834,6 +927,8 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/diffllama/modular_diffllama.py b/src/transformers/models/diffllama/modular_diffllama.py index 40e728850d9cb8..ab6ffe651584ce 100644 --- a/src/transformers/models/diffllama/modular_diffllama.py +++ b/src/transformers/models/diffllama/modular_diffllama.py @@ -24,20 +24,16 @@ from ...cache_utils import Cache, StaticCache from ...modeling_flash_attention_utils import _flash_attention_forward -from ...pytorch_utils import ALL_LAYERNORM_LAYERS from ...utils import ( is_flash_attn_greater_or_equal_2_10, logging, ) from ..gemma.modeling_gemma import GemmaForCausalLM from ..llama.modeling_llama import ( - LlamaDecoderLayer, LlamaForQuestionAnswering, LlamaForSequenceClassification, LlamaForTokenClassification, LlamaModel, - LlamaPreTrainedModel, - LlamaRMSNorm, LlamaRotaryEmbedding, apply_rotary_pos_emb, repeat_kv, @@ -52,6 +48,10 @@ _CONFIG_FOR_DOC = "DiffLlamaConfig" +class DiffLlamaRotaryEmbedding(LlamaRotaryEmbedding): + pass + + class DiffLlamaMLP(MistralMLP): pass @@ -419,6 +419,10 @@ def forward( return attn_output, None, past_key_value +class DiffLlamaModel(LlamaModel): + pass + + class DiffLlamaForCausalLM(GemmaForCausalLM): pass @@ -436,7 +440,7 @@ class DiffLlamaForTokenClassification(LlamaForTokenClassification): __all__ = [ - "DiffLlamaPreTrainedModel", + "DiffLlamaPreTrainedModel", # noqa: F822 "DiffLlamaModel", "DiffLlamaForCausalLM", "DiffLlamaForSequenceClassification",