diff --git a/src/transformers/modeling_rope_utils.py b/src/transformers/modeling_rope_utils.py index c617420a5896de..55d63fa4514a88 100644 --- a/src/transformers/modeling_rope_utils.py +++ b/src/transformers/modeling_rope_utils.py @@ -279,15 +279,10 @@ def _compute_longrope_parameters( # `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two # values to compute the default attention scaling factor, instead of using `factor`. if hasattr(config, "original_max_position_embeddings"): - if seq_len and seq_len < config.original_max_position_embeddings: - expanded_max_position_embeddings = config.original_max_position_embeddings - else: - expanded_max_position_embeddings = config.max_position_embeddings max_position_embeddings = config.original_max_position_embeddings - factor = expanded_max_position_embeddings / max_position_embeddings + factor = config.max_position_embeddings / config.original_max_position_embeddings else: max_position_embeddings = config.max_position_embeddings - expanded_max_position_embeddings = max_position_embeddings * factor # Sets the attention factor as suggested in the paper if attention_factor is None: @@ -297,7 +292,7 @@ def _compute_longrope_parameters( attention_factor = math.sqrt(1 + math.log(factor) / math.log(max_position_embeddings)) # Compute the inverse frequencies -- scaled based on the target sequence length - if expanded_max_position_embeddings > max_position_embeddings: + if seq_len and seq_len > max_position_embeddings: ext_factors = torch.tensor(long_factor, dtype=torch.float32, device=device) else: ext_factors = torch.tensor(short_factor, dtype=torch.float32, device=device) diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py index 6481d6f3c434c7..3547a1a0b49d50 100644 --- a/src/transformers/models/aria/modeling_aria.py +++ b/src/transformers/models/aria/modeling_aria.py @@ -723,11 +723,7 @@ def _init_weights(self, module): class AriaTextRotaryEmbedding(nn.Module): - def __init__( - self, - config: AriaTextConfig, - device=None, - ): + def __init__(self, config: AriaTextConfig, device=None): super().__init__() self.rope_kwargs = {} # BC: "rope_type" was originally "type" diff --git a/src/transformers/models/bamba/modeling_bamba.py b/src/transformers/models/bamba/modeling_bamba.py index c89d8d7853008d..e17e5ef712cf7a 100644 --- a/src/transformers/models/bamba/modeling_bamba.py +++ b/src/transformers/models/bamba/modeling_bamba.py @@ -120,11 +120,7 @@ def __init__(self, config: BambaConfig, batch_size, dtype=torch.float16, device= class BambaRotaryEmbedding(nn.Module): - def __init__( - self, - config: BambaConfig, - device=None, - ): + def __init__(self, config: BambaConfig, device=None): super().__init__() self.rope_kwargs = {} # BC: "rope_type" was originally "type" diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index 7b8b9547ac1c33..58a1891f00ebc2 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -1,3 +1,9 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/cohere/modular_cohere.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_cohere.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # coding=utf-8 # Copyright 2024 Cohere team. All rights reserved. # @@ -20,13 +26,10 @@ # This file is based on the LLama model definition file in transformers -"""PyTorch Cohere model.""" -import math -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch -import torch.utils.checkpoint from torch import nn from ...activations import ACT2FN @@ -34,31 +37,21 @@ from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs -from ...modeling_outputs import ( - BaseModelOutputWithPast, - CausalLMOutputWithPast, -) +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...pytorch_utils import ALL_LAYERNORM_LAYERS from ...utils import ( + LossKwargs, add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, ) from .configuration_cohere import CohereConfig -if is_flash_attn_2_available(): - from ...modeling_flash_attention_utils import _flash_attention_forward - - logger = logging.get_logger(__name__) - _CONFIG_FOR_DOC = "CohereConfig" @@ -79,49 +72,17 @@ def forward(self, hidden_states): return hidden_states.to(input_dtype) -ALL_LAYERNORM_LAYERS.append(CohereLayerNorm) - - class CohereRotaryEmbedding(nn.Module): - # Note: the forward pass of this RoPE is slightly different from Llama's, resulting in different `sin`/`cos` for - # the same parameterization. The differences are highlighted with a comment. - - def __init__( - self, - dim=None, - max_position_embeddings=2048, - base=10000, - device=None, - scaling_factor=1.0, - rope_type="default", - config: Optional[CohereConfig] = None, - ): + def __init__(self, config: CohereConfig, device=None): super().__init__() - # TODO (joao): remove the `if` below, only used for BC self.rope_kwargs = {} - if config is None: - logger.warning_once( - "`CohereRotaryEmbedding` can now be fully parameterized by passing the model config through the " - "`config` argument. All other arguments will be removed in v4.46" - ) - self.rope_kwargs = { - "rope_type": rope_type, - "factor": scaling_factor, - "dim": dim, - "base": base, - "max_position_embeddings": max_position_embeddings, - } - self.rope_type = rope_type - self.max_seq_len_cached = max_position_embeddings - self.original_max_seq_len = max_position_embeddings + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: - # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] @@ -161,7 +122,7 @@ def forward(self, x, position_ids): device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.repeat_interleave(freqs, 2, dim=-1) # This line differs from Llama's implementation + emb = torch.repeat_interleave(freqs, 2, dim=-1) # diff from Llama: we interleave() instead of cat() cos = emb.cos() sin = emb.sin() @@ -172,6 +133,60 @@ def forward(self, x, position_ids): return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) +class CohereMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + def rotate_half(x): # Split and rotate. Note that this function is different from e.g. Llama. x1 = x[..., ::2] @@ -210,36 +225,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): return q_embed.to(dtype=dtype), k_embed.to(dtype=dtype) -class CohereMLP(nn.Module): - def __init__(self, config): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) - self.act_fn = ACT2FN[config.hidden_act] - - # Ignore copy - def forward(self, x): - down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - return down_proj - - -# Copied from transformers.models.llama.modeling_llama.repeat_kv -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - class CohereAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -247,162 +232,57 @@ def __init__(self, config: CohereConfig, layer_idx: Optional[int] = None): super().__init__() self.config = config self.layer_idx = layer_idx - if layer_idx is None: - logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " - "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class." - ) - + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 self.attention_dropout = config.attention_dropout - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.max_position_embeddings = config.max_position_embeddings - self.rope_theta = config.rope_theta self.is_causal = True - self.use_qk_norm = config.use_qk_norm - - if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." - ) + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + self.use_qk_norm = config.use_qk_norm if self.use_qk_norm: # When sharding the model using Tensor Parallelism, need to be careful to use n_local_heads - self.q_norm = CohereLayerNorm(hidden_size=(self.num_heads, self.head_dim), eps=config.layer_norm_eps) - self.k_norm = CohereLayerNorm( - hidden_size=(self.num_key_value_heads, self.head_dim), eps=config.layer_norm_eps + self.q_norm = CohereLayerNorm( + hidden_size=(config.num_attention_heads, self.head_dim), eps=config.layer_norm_eps ) - - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) - if self.use_qk_norm: - query_states = self.q_norm(query_states) - key_states = self.k_norm(key_states) - - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; position_ids needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" + self.k_norm = CohereLayerNorm( + hidden_size=(config.num_key_value_heads, self.head_dim), eps=config.layer_norm_eps ) - attn_output = attn_output.transpose(1, 2).contiguous() - - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -# NO LONGER EXIST Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->Cohere -# TODO cyril: modular -class CohereFlashAttention2(CohereAttention): - """ - Cohere flash attention module. This module inherits from `CohereAttention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() - - # Ignore copy def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - **kwargs, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if isinstance(past_key_value, StaticCache): - raise ValueError( - "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " - "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" - ) - output_attentions = False - - bsz, q_len, _ = hidden_states.size() + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + query_states = self.q_proj(hidden_states).view(hidden_shape) + key_states = self.k_proj(hidden_states).view(hidden_shape) + value_states = self.v_proj(hidden_states).view(hidden_shape) - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) - if self.use_qk_norm: + if self.use_qk_norm: # main diff from Llama query_states = self.q_norm(query_states) key_states = self.k_norm(key_states) query_states = query_states.transpose(1, 2) key_states = key_states.transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.transpose(1, 2) cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) @@ -412,169 +292,37 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache - # to be able to avoid many of these transpose/reshape/view. - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - dropout_rate = self.attention_dropout if self.training else 0.0 - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. (CohereLayerNorm handles it correctly) - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - attn_output = _flash_attention_forward( + attn_output, attn_weights = attention_interface( + self, query_states, key_states, value_states, attention_mask, - q_len, - dropout=dropout_rate, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - is_causal=self.is_causal, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, ) - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class CohereSdpaAttention(CohereAttention): - """ - Cohere attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from - `CohereAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to - SDPA API. - """ - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "CohereModel is using CohereSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - ) - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) - if self.use_qk_norm: - query_states = self.q_norm(query_states) - key_states = self.k_norm(key_states) - - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - causal_mask = attention_mask - # if attention_mask is not None and cache_position is not None: - if attention_mask is not None: - causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and causal_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - is_causal = True if causal_mask is None and q_len > 1 else False - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=causal_mask, - dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=is_causal, - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(bsz, q_len, self.hidden_size) - - attn_output = self.o_proj(attn_output) - - return attn_output, None, past_key_value - - -COHERE_ATTENTION_CLASSES = { - "eager": CohereAttention, - "flash_attention_2": CohereFlashAttention2, - "sdpa": CohereSdpaAttention, -} + return attn_output, attn_weights class CohereDecoderLayer(nn.Module): def __init__(self, config: CohereConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size - - self.self_attn = COHERE_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) - + self.self_attn = CohereAttention(config=config, layer_idx=layer_idx) self.mlp = CohereMLP(config) self.input_layernorm = CohereLayerNorm(hidden_size=(config.hidden_size), eps=config.layer_norm_eps) @@ -583,11 +331,12 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -595,13 +344,13 @@ def forward( attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, query_sequence_length, key_sequence_length)` if default attention is used. + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): Indices depicting the position of the input sequence tokens in the sequence position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): @@ -613,7 +362,7 @@ def forward( hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states_attention, self_attn_weights, present_key_value = self.self_attn( + hidden_states_attention, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, @@ -622,6 +371,7 @@ def forward( use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, + **kwargs, ) # Fully Connected @@ -631,19 +381,16 @@ def forward( hidden_states = residual + hidden_states_attention + hidden_states_mlp outputs = (hidden_states,) - if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) - return outputs COHERE_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings etc.). + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage @@ -661,7 +408,6 @@ def forward( "The bare Cohere Model outputting raw hidden-states without any specific head on top.", COHERE_START_DOCSTRING, ) -# Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel with Llama->Cohere class CoherePreTrainedModel(PreTrainedModel): config_class = CohereConfig base_model_prefix = "model" @@ -754,6 +500,10 @@ def _init_weights(self, module): more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. """ @@ -761,8 +511,6 @@ def _init_weights(self, module): "The bare Cohere Model outputting raw hidden-states without any specific head on top.", COHERE_START_DOCSTRING, ) -# copied from transformers.models.llama.modeling_llama.LlamaModel with Llama->Cohere, LLAMA->COHERE -# TODO cyril: modular class CohereModel(CoherePreTrainedModel): """ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`CohereDecoderLayer`] @@ -771,7 +519,6 @@ class CohereModel(CoherePreTrainedModel): config: CohereConfig """ - # Ignore copy def __init__(self, config: CohereConfig): super().__init__(config) self.padding_idx = config.pad_token_id @@ -800,7 +547,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, @@ -1023,11 +770,13 @@ def _prepare_4d_causal_attention_mask_with_cache_position( return causal_mask -# TODO: re-enable check: Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with Llama->Cohere +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... + + class CohereForCausalLM(CoherePreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} - # Ignore copy def __init__(self, config): super().__init__(config) self.model = CohereModel(config) @@ -1035,6 +784,7 @@ def __init__(self, config): self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) self.logit_scale = config.logit_scale self.tie_word_embeddings = config.tie_word_embeddings + # Initialize weights and apply final processing self.post_init() @@ -1056,7 +806,6 @@ def set_decoder(self, decoder): def get_decoder(self): return self.model - # Ignore copy @add_start_docstrings_to_model_forward(COHERE_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -1064,7 +813,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, @@ -1073,7 +822,7 @@ def forward( return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, num_logits_to_keep: int = 0, - **loss_kwargs, + **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -1123,16 +872,17 @@ def forward( output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, + **kwargs, ) hidden_states = outputs[0] # Only compute necessary logits, and do not upcast them to float if we are not computing the loss logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) - logits = logits * self.logit_scale + logits = logits * self.logit_scale # main diff from Llama loss = None if labels is not None: - loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs) + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) if not return_dict: output = (logits,) + outputs[1:] diff --git a/src/transformers/models/cohere/modular_cohere.py b/src/transformers/models/cohere/modular_cohere.py new file mode 100644 index 00000000000000..b3483103461e9a --- /dev/null +++ b/src/transformers/models/cohere/modular_cohere.py @@ -0,0 +1,393 @@ +# coding=utf-8 +# Copyright 2024 Cohere team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This file is based on the LLama model definition file in transformers + +"""PyTorch Cohere model.""" + +from typing import Callable, List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn + +from ...cache_utils import Cache +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_outputs import CausalLMOutputWithPast +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS +from ...processing_utils import Unpack +from ...pytorch_utils import ALL_LAYERNORM_LAYERS +from ...utils import LossKwargs, logging +from ..llama.modeling_llama import ( + LlamaAttention, + LlamaForCausalLM, + LlamaMLP, + LlamaModel, + LlamaRotaryEmbedding, + eager_attention_forward, +) +from .configuration_cohere import CohereConfig + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "CohereConfig" + + +class CohereLayerNorm(nn.Module): + def __init__(self, hidden_size=None, eps=1e-5, bias=False): + """The hidden size can be a tuple or an int. The tuple is used for QKNorm to normalize across head_dim""" + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + mean = hidden_states.mean(-1, keepdim=True) + variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True) + hidden_states = (hidden_states - mean) * torch.rsqrt(variance + self.variance_epsilon) + hidden_states = self.weight.to(torch.float32) * hidden_states + return hidden_states.to(input_dtype) + + +ALL_LAYERNORM_LAYERS.append(CohereLayerNorm) + + +class CohereRotaryEmbedding(LlamaRotaryEmbedding): + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.repeat_interleave(freqs, 2, dim=-1) # diff from Llama: we interleave() instead of cat() + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +def rotate_half(x): + # Split and rotate. Note that this function is different from e.g. Llama. + x1 = x[..., ::2] + x2 = x[..., 1::2] + rot_x = torch.stack([-x2, x1], dim=-1).flatten(-2) + return rot_x + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + dtype = q.dtype + q = q.float() + k = k.float() + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed.to(dtype=dtype), k_embed.to(dtype=dtype) + + +class CohereMLP(LlamaMLP): + def __init__(self, config): + super().__init__(config) + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + + +class CohereAttention(LlamaAttention): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: CohereConfig, layer_idx: Optional[int] = None): + super().__init__(config, layer_idx) + self.use_qk_norm = config.use_qk_norm + if self.use_qk_norm: + # When sharding the model using Tensor Parallelism, need to be careful to use n_local_heads + self.q_norm = CohereLayerNorm( + hidden_size=(config.num_attention_heads, self.head_dim), eps=config.layer_norm_eps + ) + self.k_norm = CohereLayerNorm( + hidden_size=(config.num_key_value_heads, self.head_dim), eps=config.layer_norm_eps + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape) + key_states = self.k_proj(hidden_states).view(hidden_shape) + value_states = self.v_proj(hidden_states).view(hidden_shape) + + if self.use_qk_norm: # main diff from Llama + query_states = self.q_norm(query_states) + key_states = self.k_norm(key_states) + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; position_ids needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class CohereDecoderLayer(nn.Module): + def __init__(self, config: CohereConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = CohereAttention(config=config, layer_idx=layer_idx) + self.mlp = CohereMLP(config) + self.input_layernorm = CohereLayerNorm(hidden_size=(config.hidden_size), eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. + """ + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states_attention, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + + # Fully Connected + hidden_states_mlp = self.mlp(hidden_states) + + # Add everything together + hidden_states = residual + hidden_states_attention + hidden_states_mlp + + outputs = (hidden_states,) + if output_attentions: + outputs += (self_attn_weights,) + + return outputs + + +class CohereModel(LlamaModel): + def __init__(self, config: CohereConfig): + super().__init__(config) + self.layers = nn.ModuleList( + [CohereDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.rotary_emb = CohereRotaryEmbedding(config=config) + self.norm = CohereLayerNorm(hidden_size=(config.hidden_size), eps=config.layer_norm_eps) + + +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... + + +class CohereForCausalLM(LlamaForCausalLM): + def __init__(self, config): + super().__init__(config) + self.model = CohereModel(config) + self.logit_scale = config.logit_scale + self.tie_word_embeddings = config.tie_word_embeddings + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, + **kwargs: Unpack[KwargsForCausalLM], + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + + Returns: + + Example: + + ```python + >> from transformers import AutoTokenizer, CohereForCausalLM + + >> model = CohereForCausalLM.from_pretrained("CohereForAI/c4ai-command-r-v01") + >> tokenizer = AutoTokenizer.from_pretrained("CohereForAI/c4ai-command-r-v01") + + >> prompt = "Hey, are you conscious? Can you talk to me?" + >> inputs = tokenizer(prompt, return_tensors="pt") + + >> # Generate + >> generate_ids = model.generate(inputs.input_ids, max_length=30) + >> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + logits = logits * self.logit_scale # main diff from Llama + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/src/transformers/models/cohere2/modeling_cohere2.py b/src/transformers/models/cohere2/modeling_cohere2.py index 1ffa4bffddc3df..d0463573d2b1d8 100644 --- a/src/transformers/models/cohere2/modeling_cohere2.py +++ b/src/transformers/models/cohere2/modeling_cohere2.py @@ -28,10 +28,13 @@ from ...activations import ACT2FN from ...cache_utils import Cache, HybridCache from ...generation import GenerationMixin +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import PreTrainedModel +from ...processing_utils import Unpack from ...utils import ( + LossKwargs, add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, @@ -46,50 +49,20 @@ logger = logging.get_logger(__name__) - _CONFIG_FOR_DOC = "Cohere2Config" class Cohere2RotaryEmbedding(nn.Module): - # Note: the forward pass of this RoPE is slightly different from Llama's, resulting in different `sin`/`cos` for - # the same parameterization. The differences are highlighted with a comment. - - def __init__( - self, - dim=None, - max_position_embeddings=2048, - base=10000, - device=None, - scaling_factor=1.0, - rope_type="default", - config: Optional[Cohere2Config] = None, - ): + def __init__(self, config: Cohere2Config, device=None): super().__init__() - # TODO (joao): remove the `if` below, only used for BC self.rope_kwargs = {} - if config is None: - logger.warning_once( - "`Cohere2RotaryEmbedding` can now be fully parameterized by passing the model config through the " - "`config` argument. All other arguments will be removed in v4.46" - ) - self.rope_kwargs = { - "rope_type": rope_type, - "factor": scaling_factor, - "dim": dim, - "base": base, - "max_position_embeddings": max_position_embeddings, - } - self.rope_type = rope_type - self.max_seq_len_cached = max_position_embeddings - self.original_max_seq_len = max_position_embeddings + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: - # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] @@ -129,7 +102,7 @@ def forward(self, x, position_ids): device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.repeat_interleave(freqs, 2, dim=-1) # This line differs from Llama's implementation + emb = torch.repeat_interleave(freqs, 2, dim=-1) # diff from Llama: we interleave() instead of cat() cos = emb.cos() sin = emb.sin() @@ -157,6 +130,18 @@ def forward(self, hidden_states): return hidden_states.to(input_dtype) +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + def rotate_half(x): # Split and rotate. Note that this function is different from e.g. Llama. x1 = x[..., ::2] @@ -195,18 +180,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): return q_embed.to(dtype=dtype), k_embed.to(dtype=dtype) -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - def eager_attention_forward( config: Cohere2Config, query: torch.Tensor, @@ -425,7 +398,6 @@ def __init__(self, config): self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) self.act_fn = ACT2FN[config.hidden_act] - # Ignore copy def forward(self, x): down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) return down_proj @@ -436,7 +408,6 @@ def __init__(self, config: Cohere2Config, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size self.self_attn = Cohere2Attention(config, layer_idx) - self.mlp = Cohere2MLP(config) self.input_layernorm = Cohere2LayerNorm(hidden_size=(config.hidden_size), eps=config.layer_norm_eps) self.config = config @@ -521,7 +492,8 @@ def forward( COHERE2_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings etc.). + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage @@ -874,11 +846,13 @@ def _prepare_4d_causal_attention_mask_with_cache_position( return causal_mask -# TODO: re-enable check: Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with Llama->Cohere2 +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... + + class Cohere2ForCausalLM(Cohere2PreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} - # Ignore copy def __init__(self, config: Cohere2Config): super().__init__(config) self.model = Cohere2Model(config) @@ -886,6 +860,7 @@ def __init__(self, config: Cohere2Config): self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) self.logit_scale = config.logit_scale self.tie_word_embeddings = config.tie_word_embeddings + # Initialize weights and apply final processing self.post_init() @@ -907,7 +882,6 @@ def set_decoder(self, decoder): def get_decoder(self): return self.model - # Ignore copy @add_start_docstrings_to_model_forward(COHERE2_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -915,7 +889,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, @@ -924,7 +898,7 @@ def forward( return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, num_logits_to_keep: int = 0, - **loss_kwargs, + **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -974,16 +948,17 @@ def forward( output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, + **kwargs, ) hidden_states = outputs[0] # Only compute necessary logits, and do not upcast them to float if we are not computing the loss logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) - logits = logits * self.logit_scale + logits = logits * self.logit_scale # main diff from Llama loss = None if labels is not None: - loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs) + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) if not return_dict: output = (logits,) + outputs[1:] diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index e2ea12b03fe434..f2611b16c34558 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -92,11 +92,7 @@ def forward(self, x): class GemmaRotaryEmbedding(nn.Module): - def __init__( - self, - config: GemmaConfig, - device=None, - ): + def __init__(self, config: GemmaConfig, device=None): super().__init__() self.rope_kwargs = {} # BC: "rope_type" was originally "type" diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 67fc6c86a3bac6..0bf2b154f9c221 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -324,11 +324,7 @@ def forward( class Gemma2RotaryEmbedding(nn.Module): - def __init__( - self, - config: Gemma2Config, - device=None, - ): + def __init__(self, config: Gemma2Config, device=None): super().__init__() self.rope_kwargs = {} # BC: "rope_type" was originally "type" diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index 95ad0d9719951d..87a1c75e5084ed 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -62,7 +62,6 @@ def __init__(self, config): self.config = config self.gate_up_proj = nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False) self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False) - self.activation_fn = ACT2FN[config.hidden_act] def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: @@ -256,11 +255,7 @@ def extra_repr(self): class GlmRotaryEmbedding(nn.Module): - def __init__( - self, - config: GlmConfig, - device=None, - ): + def __init__(self, config: GlmConfig, device=None): super().__init__() self.rope_kwargs = {} # BC: "rope_type" was originally "type" diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py index 2e045e149d95de..49b6b7ed2df237 100644 --- a/src/transformers/models/granite/modeling_granite.py +++ b/src/transformers/models/granite/modeling_granite.py @@ -309,11 +309,7 @@ def forward( class GraniteRotaryEmbedding(nn.Module): - def __init__( - self, - config: GraniteConfig, - device=None, - ): + def __init__(self, config: GraniteConfig, device=None): super().__init__() self.rope_kwargs = {} # BC: "rope_type" was originally "type" diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 5be33c26414cd7..61e7fa0f6ca429 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -80,11 +80,7 @@ def extra_repr(self): class LlamaRotaryEmbedding(nn.Module): - def __init__( - self, - config: LlamaConfig, - device=None, - ): + def __init__(self, config: LlamaConfig, device=None): super().__init__() self.rope_kwargs = {} # BC: "rope_type" was originally "type" diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 90c38895b4280b..8c0d8af3ece737 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -270,11 +270,7 @@ def forward( class MistralRotaryEmbedding(nn.Module): - def __init__( - self, - config: MistralConfig, - device=None, - ): + def __init__(self, config: MistralConfig, device=None): super().__init__() self.rope_kwargs = {} # BC: "rope_type" was originally "type" diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 84ed327d9be920..a726b69fb6688f 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -392,11 +392,7 @@ def forward( class MixtralRotaryEmbedding(nn.Module): - def __init__( - self, - config: MixtralConfig, - device=None, - ): + def __init__(self, config: MixtralConfig, device=None): super().__init__() self.rope_kwargs = {} # BC: "rope_type" was originally "type" diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 11d3d99f4f72c9..8f43f303dc8869 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -274,11 +274,7 @@ def forward( class OlmoRotaryEmbedding(nn.Module): - def __init__( - self, - config: OlmoConfig, - device=None, - ): + def __init__(self, config: OlmoConfig, device=None): super().__init__() self.rope_kwargs = {} # BC: "rope_type" was originally "type" diff --git a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py index 49ae798e7f1101..dc5893cad55473 100644 --- a/src/transformers/models/olmo2/modeling_olmo2.py +++ b/src/transformers/models/olmo2/modeling_olmo2.py @@ -275,11 +275,7 @@ def forward( class Olmo2RotaryEmbedding(nn.Module): - def __init__( - self, - config: Olmo2Config, - device=None, - ): + def __init__(self, config: Olmo2Config, device=None): super().__init__() self.rope_kwargs = {} # BC: "rope_type" was originally "type" diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 477896decd5318..b4079c3ef7d417 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -270,11 +270,7 @@ def forward( class PhiRotaryEmbedding(nn.Module): - def __init__( - self, - config: PhiConfig, - device=None, - ): + def __init__(self, config: PhiConfig, device=None): super().__init__() self.rope_kwargs = {} # BC: "rope_type" was originally "type" diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index 908fd982b9c73c..551f610908bb5d 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -1,3 +1,9 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/phi3/modular_phi3.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_phi3.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # coding=utf-8 # Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved. # @@ -13,34 +19,31 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""PyTorch Phi-3 model.""" -import math -import warnings -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch -import torch.utils.checkpoint from torch import nn -from torch.nn import CrossEntropyLoss from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter -from ...modeling_flash_attention_utils import _flash_attention_forward +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast, TokenClassifierOutput, ) -from ...modeling_utils import PreTrainedModel +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack from ...utils import ( + LossKwargs, add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, ) @@ -53,186 +56,24 @@ _CONFIG_FOR_DOC = "Phi3Config" -# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Phi3 -class Phi3RMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - Phi3RMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - - def extra_repr(self): - return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" - - -# copied from transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding with gemma->phi3, Gemma->Phi3 -# TODO cyril: modular -class Phi3RotaryEmbedding(nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): +class Phi3MLP(nn.Module): + def __init__(self, config): super().__init__() - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim)) - self.register_buffer("inv_freq", tensor=inv_freq, persistent=False) - - @torch.no_grad() - def forward(self, x, position_ids, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - self.inv_freq.to(x.device) - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) - position_ids_expanded = position_ids[:, None, :].float() - # Force float32 since bfloat16 loses precision on long contexts - # See https://github.com/huggingface/transformers/pull/29285 - device_type = x.device.type - device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() - sin = emb.sin() - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) - - -class Phi3SuScaledRotaryEmbedding(Phi3RotaryEmbedding): - def __init__(self, dim, config, device=None): - warnings.warn( - "The class Phi3SuScaledRotaryEmbedding is deprecated and will be removed in version 5 of Transformers. Please" - " use Phi3LongRoPEScaledRotaryEmbedding instead.", - FutureWarning, - ) - super().__init__(dim, config.max_position_embeddings, config.rope_theta, device) - - self.short_factor = config.rope_scaling["short_factor"] - self.long_factor = config.rope_scaling["long_factor"] - self.original_max_position_embeddings = config.original_max_position_embeddings - - @torch.no_grad() - def forward(self, x, position_ids, seq_len=None): - seq_len = torch.max(position_ids) + 1 - if seq_len > self.original_max_position_embeddings: - ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device) - else: - ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device) - inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim - self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape) - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) - position_ids_expanded = position_ids[:, None, :].float() - # Force float32 since bfloat16 loses precision on long contexts - # See https://github.com/huggingface/transformers/pull/29285 - device_type = x.device.type - device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - scale = self.max_position_embeddings / self.original_max_position_embeddings - if scale <= 1.0: - scaling_factor = 1.0 - else: - scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings)) - cos = emb.cos() * scaling_factor - sin = emb.sin() * scaling_factor - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) - - -class Phi3YarnScaledRotaryEmbedding(Phi3RotaryEmbedding): - def __init__(self, dim, config, device=None): - warnings.warn( - "The class Phi3YarnScaledRotaryEmbedding is deprecated and will be removed in version 5 of Transformers", - FutureWarning, - ) - super().__init__(dim, config.max_position_embeddings, config.rope_theta, device) - - self.short_factor = config.rope_scaling["short_factor"] - self.long_factor = config.rope_scaling["long_factor"] - self.original_max_position_embeddings = config.original_max_position_embeddings - - @torch.no_grad() - def forward(self, x, position_ids, seq_len=None): - seq_len = torch.max(position_ids) + 1 - if seq_len > self.original_max_position_embeddings: - ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device) - else: - ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device) - - inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim - self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape) - - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) - position_ids_expanded = position_ids[:, None, :].float() - - # Force float32 since bfloat16 loses precision on long contexts - # See https://github.com/huggingface/transformers/pull/29285 - device_type = x.device.type - device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - - scale = self.max_position_embeddings / self.original_max_position_embeddings - if scale <= 1.0: - scaling_factor = 1.0 - else: - scaling_factor = 0.1 * math.log(scale) + 1.0 - - cos = emb.cos() * scaling_factor - sin = emb.sin() * scaling_factor - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) - - -class Phi3LongRoPEScaledRotaryEmbedding(Phi3RotaryEmbedding): - def __init__(self, dim, config, device=None): - super().__init__(dim, config.max_position_embeddings, config.rope_theta, device) - - self.short_factor = config.rope_scaling["short_factor"] - self.long_factor = config.rope_scaling["long_factor"] - self.original_max_position_embeddings = config.original_max_position_embeddings - - @torch.no_grad() - def forward(self, x, position_ids, seq_len=None): - seq_len = seq_len or torch.max(position_ids) + 1 - if seq_len > self.original_max_position_embeddings: - ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device) - else: - ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device) - - inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim - self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape) - - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) - position_ids_expanded = position_ids[:, None, :].float() + self.config = config + self.gate_up_proj = nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False) + self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False) + self.activation_fn = ACT2FN[config.hidden_act] - # Force float32 since bfloat16 loses precision on long contexts - # See https://github.com/huggingface/transformers/pull/29285 - device_type = x.device.type - device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) + def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: + up_states = self.gate_up_proj(hidden_states) - scale = self.max_position_embeddings / self.original_max_position_embeddings - if scale <= 1.0: - scaling_factor = 1.0 - else: - scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings)) + gate, up_states = up_states.chunk(2, dim=-1) + up_states = up_states * self.activation_fn(gate) - cos = emb.cos() * scaling_factor - sin = emb.sin() * scaling_factor - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + return self.down_proj(up_states) -# Copied from transformers.models.llama.modeling_llama.rotate_half def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] @@ -240,7 +81,6 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -268,26 +108,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): return q_embed, k_embed -class Phi3MLP(nn.Module): - def __init__(self, config): - super().__init__() - - self.config = config - self.gate_up_proj = nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False) - self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False) - - self.activation_fn = ACT2FN[config.hidden_act] - - def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: - up_states = self.gate_up_proj(hidden_states) - - gate, up_states = up_states.chunk(2, dim=-1) - up_states = up_states * self.activation_fn(gate) - - return self.down_proj(up_states) - - -# Copied from transformers.models.llama.modeling_llama.repeat_kv with llama->phi def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, @@ -300,6 +120,32 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + class Phi3Attention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -307,374 +153,117 @@ def __init__(self, config: Phi3Config, layer_idx: Optional[int] = None): super().__init__() self.config = config self.layer_idx = layer_idx - if layer_idx is None: - logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " - "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class." - ) - - self.attention_dropout = config.attention_dropout - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.max_position_embeddings = config.max_position_embeddings - self.original_max_position_embeddings = config.original_max_position_embeddings - self.rope_theta = config.rope_theta - self.rope_scaling = config.rope_scaling + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout self.is_causal = True - if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." - ) - - op_size = self.num_heads * self.head_dim + 2 * (self.num_key_value_heads * self.head_dim) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) - self.qkv_proj = nn.Linear(self.hidden_size, op_size, bias=False) - self._init_rope() - - def _init_rope(self): - if self.rope_scaling is None: - self.rotary_emb = Phi3RotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - base=self.rope_theta, - ) - else: - scaling_type = self.config.rope_scaling["type"] - if scaling_type == "longrope": - self.rotary_emb = Phi3LongRoPEScaledRotaryEmbedding(self.head_dim, self.config) - else: - raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + op_size = config.num_attention_heads * self.head_dim + 2 * (config.num_key_value_heads * self.head_dim) + self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) + self.qkv_proj = nn.Linear(config.hidden_size, op_size, bias=False) def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - logger.warning_once("You are not running the flash-attention implementation, expect numerical differences.") - - bsz, q_len, _ = hidden_states.size() + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) qkv = self.qkv_proj(hidden_states) - query_pos = self.num_heads * self.head_dim + query_pos = self.config.num_attention_heads * self.head_dim query_states = qkv[..., :query_pos] key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim] value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :] - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + query_states = query_states.view(hidden_shape).transpose(1, 2) + key_states = key_states.view(hidden_shape).transpose(1, 2) + value_states = value_states.view(hidden_shape).transpose(1, 2) - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - if self.layer_idx is None: - raise ValueError( - f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " - "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " - "with a layer index." - ) - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len) - - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - - if attention_mask is not None: - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights += causal_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class Phi3FlashAttention2(Phi3Attention): - """ - Phi-3 flash attention module. This module inherits from `Phi3Attention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - # Phi3FlashAttention2 attention does not support output_attentions - - output_attentions = False - - bsz, q_len, _ = hidden_states.size() - - qkv = self.qkv_proj(hidden_states) - query_pos = self.num_heads * self.head_dim - query_states = qkv[..., :query_pos] - key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim] - value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :] - - # Flash attention requires the input to have the shape - # batch_size x seq_length x head_dim x hidden_dim - # therefore we just need to keep the original shape - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - if self.layer_idx is None: - raise ValueError( - f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " - "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " - "with a layer index." + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' ) - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - - # Because the input can be padded, the absolute sequence length depends on the max position id. - rotary_seq_len = ( - max(kv_seq_len, position_ids[:, -1].max().item() + 1) if position_ids is not None else kv_seq_len - ) - - cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len, position_ids=position_ids) - - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - - if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_dropout = self.attention_dropout if self.training else 0.0 - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. - - if query_states.dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype else: - target_dtype = self.qkv_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - # Reashape to the expected shape for Flash Attention - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - attn_output = _flash_attention_forward( + attn_output, attn_weights = attention_interface( + self, query_states, key_states, value_states, attention_mask, - q_len, - position_ids=position_ids, - dropout=attn_dropout, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, sliding_window=getattr(self.config, "sliding_window", None), - use_top_left_mask=self._flash_attn_uses_top_left_mask, - is_causal=self.is_causal, + **kwargs, ) - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) + return attn_output, attn_weights - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -# NO LONGER EXIST copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Phi3 -# TODO cyril: modular -class Phi3SdpaAttention(Phi3Attention): - """ - Phi3 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from - `Phi3Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to - SDPA API. - """ - - # Adapted from Phi3Attention.forward - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "Phi3Model is using Phi3SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - - bsz, q_len, _ = hidden_states.size() - - qkv = self.qkv_proj(hidden_states) - query_pos = self.num_heads * self.head_dim - query_states = qkv[..., :query_pos] - key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim] - value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :] - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len) - - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - - if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - causal_mask = attention_mask - if attention_mask is not None: - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and attention_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. - is_causal = True if causal_mask is None and q_len > 1 else False - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=causal_mask, - dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=is_causal, - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(bsz, q_len, self.hidden_size) - - attn_output = self.o_proj(attn_output) - - return attn_output, None, past_key_value +class Phi3RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Phi3RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) -PHI3_ATTENTION_CLASSES = { - "eager": Phi3Attention, - "flash_attention_2": Phi3FlashAttention2, - "sdpa": Phi3SdpaAttention, -} + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" class Phi3DecoderLayer(nn.Module): def __init__(self, config: Phi3Config, layer_idx: int): super().__init__() - - self.config = config - self.self_attn = PHI3_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx) - + self.hidden_size = config.hidden_size + self.self_attn = Phi3Attention(config=config, layer_idx=layer_idx) self.mlp = Phi3MLP(config) self.input_layernorm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - + self.post_attention_layernorm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.config = config self.resid_attn_dropout = nn.Dropout(config.resid_pdrop) self.resid_mlp_dropout = nn.Dropout(config.resid_pdrop) - self.post_attention_layernorm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - **kwargs, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -685,26 +274,25 @@ def forward( position_ids (`torch.LongTensor` of shape `({0})`, *optional*): Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) + past_key_value (`Cache`, *optional*): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): Indices depicting the position of the input sequence tokens in the sequence kwargs (`dict`, *optional*): Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code into the model """ - residual = hidden_states hidden_states = self.input_layernorm(hidden_states) # Self Attention - attn_outputs, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, @@ -712,26 +300,95 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, ) - - hidden_states = residual + self.resid_attn_dropout(attn_outputs) + hidden_states = residual + self.resid_attn_dropout(hidden_states) # main diff with Llama residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) - hidden_states = residual + self.resid_mlp_dropout(hidden_states) + hidden_states = residual + self.resid_mlp_dropout(hidden_states) # main diff with Llama outputs = (hidden_states,) - if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) - return outputs +class Phi3RotaryEmbedding(nn.Module): + def __init__(self, config: Phi3Config, device=None): + super().__init__() + self.rope_kwargs = {} + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + if self.rope_type == "longrope": + long_inv_freq, _ = self.rope_init_fn( + self.config, device, seq_len=config.original_max_position_embeddings + 1, **self.rope_kwargs + ) + self.register_buffer("long_inv_freq", long_inv_freq, persistent=False) + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + inv_freq = self.inv_freq + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + inv_freq = self.inv_freq + elif self.rope_type == "longrope": + seq_len = torch.max(position_ids) + 1 + if seq_len > self.config.original_max_position_embeddings: + inv_freq = self.long_inv_freq + + # Core RoPE block + inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + PHI3_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads @@ -750,7 +407,7 @@ def forward( @add_start_docstrings( - "The bare Phi-3 model outputting raw hidden-states without any specific head on top.", + "The bare Phi3 Model outputting raw hidden-states without any specific head on top.", PHI3_START_DOCSTRING, ) class Phi3PreTrainedModel(PreTrainedModel): @@ -758,11 +415,12 @@ class Phi3PreTrainedModel(PreTrainedModel): base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["Phi3DecoderLayer"] - _skip_keys_device_placement = "past_key_values" + _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True _supports_sdpa = True _supports_cache_class = True - + _supports_quantized_cache = True + _supports_static_cache = True _version = "0.0.5" def _init_weights(self, module): @@ -853,7 +511,7 @@ def _init_weights(self, module): @add_start_docstrings( - "The bare Phi-3 model outputting raw hidden-states without any specific head on top.", + "The bare Phi3 Model outputting raw hidden-states without any specific head on top.", PHI3_START_DOCSTRING, ) class Phi3Model(Phi3PreTrainedModel): @@ -870,14 +528,13 @@ def __init__(self, config: Phi3Config): self.vocab_size = config.vocab_size self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) - self.embed_dropout = nn.Dropout(config.embd_pdrop) self.layers = nn.ModuleList( [Phi3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) - self._attn_implementation = config._attn_implementation self.norm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - + self.rotary_emb = Phi3RotaryEmbedding(config=config) self.gradient_checkpointing = False + # Initialize weights and apply final processing self.post_init() @@ -893,54 +550,43 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - # kept for BC (non `Cache` `past_key_values` inputs) - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True - if past_key_values is None: - past_key_values = DynamicCache() - else: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " - "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class " - "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" - ) + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) + if position_ids is None: position_ids = cache_position.unsqueeze(0) @@ -950,12 +596,14 @@ def forward( hidden_states = inputs_embeds + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_decoder_cache = None - for decoder_layer in self.layers: + for decoder_layer in self.layers[: self.config.num_hidden_layers]: if output_hidden_states: all_hidden_states += (hidden_states,) @@ -969,6 +617,7 @@ def forward( output_attentions, use_cache, cache_position, + position_embeddings, ) else: layer_outputs = decoder_layer( @@ -979,13 +628,12 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -995,18 +643,13 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None - if return_legacy_cache: - next_cache = next_cache.to_legacy_cache() - - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( + output = BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, ) + return output if return_dict else output.to_tuple() def _update_causal_mask( self, @@ -1017,6 +660,14 @@ def _update_causal_mask( output_attentions: bool, ): if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and past_key_values is not None: + is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] + if is_padding_right: + raise ValueError( + "You are attempting to perform batched generation with padding_side='right'" + " this may lead to unexpected behaviour for Flash Attention version of Phi3. Make sure to " + " call `tokenizer.padding_side = 'left'` before tokenizing the input. " + ) if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None @@ -1084,7 +735,6 @@ def _update_causal_mask( return causal_mask @staticmethod - # Copied from transformers.models.mistral.modeling_mistral.MistralModel._prepare_4d_causal_attention_mask_with_cache_position with Mistral->Phi3 def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, @@ -1152,10 +802,13 @@ def _prepare_4d_causal_attention_mask_with_cache_position( return causal_mask +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... + + class Phi3ForCausalLM(Phi3PreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} - # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with Llama->Phi3 def __init__(self, config): super().__init__(config) self.model = Phi3Model(config) @@ -1165,31 +818,24 @@ def __init__(self, config): # Initialize weights and apply final processing self.post_init() - # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_input_embeddings def get_input_embeddings(self): return self.model.embed_tokens - # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_input_embeddings def set_input_embeddings(self, value): self.model.embed_tokens = value - # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_output_embeddings def get_output_embeddings(self): return self.lm_head - # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_output_embeddings def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings - # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_decoder def set_decoder(self, decoder): self.model = decoder - # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_decoder def get_decoder(self): return self.model - # Ignore copy @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -1197,7 +843,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, @@ -1206,7 +852,7 @@ def forward( return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, num_logits_to_keep: int = 0, - **loss_kwargs, + **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -1227,27 +873,17 @@ def forward( ```python >>> from transformers import AutoTokenizer, Phi3ForCausalLM - >>> model = Phi3ForCausalLM.from_pretrained("microsoft/phi-3-mini-4k-instruct") - >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-3-mini-4k-instruct") + >>> model = Phi3ForCausalLM.from_pretrained("meta-phi3/Phi3-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-phi3/Phi3-2-7b-hf") - >>> prompt = "This is an example script ." + >>> prompt = "Hey, are you conscious? Can you talk to me?" >>> inputs = tokenizer(prompt, return_tensors="pt") >>> # Generate >>> generate_ids = model.generate(inputs.input_ids, max_length=30) >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - 'This is an example script .\n Certainly! Below is a sample script that demonstrates a simple task, such as calculating the sum' + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." ```""" - if ( - use_cache - and self.config.rope_scaling - and cache_position is not None - and cache_position[0] == self.config.original_max_position_embeddings - ): - logger.warning( - f"If you are not using the generate method, you may encounter nonsensical outputs after the {self.config.original_max_position_embeddings}th token, as the KV cache needs to be recomputed." - ) - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -1265,6 +901,8 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, + **kwargs, ) hidden_states = outputs[0] @@ -1273,7 +911,7 @@ def forward( loss = None if labels is not None: - loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs) + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) if not return_dict: output = (logits,) + outputs[1:] @@ -1329,7 +967,7 @@ def prepare_inputs_for_generation( @add_start_docstrings( """ - The [`Phi3Model`] with a sequence classification head on top (linear layer). + The Phi3 Model transformer with a sequence classification head on top (linear layer). [`Phi3ForSequenceClassification`] uses the last token in order to do the classification, as other causal models (e.g. GPT-2) do. @@ -1342,7 +980,6 @@ def prepare_inputs_for_generation( """, PHI3_START_DOCSTRING, ) -# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Phi3, LLAMA->PHI3, self.transformer->self.model, transformer_outputs->model_outputs class Phi3ForSequenceClassification(Phi3PreTrainedModel): def __init__(self, config): super().__init__(config) @@ -1381,7 +1018,7 @@ def forward( """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict - model_outputs = self.model( + transformer_outputs = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -1392,7 +1029,7 @@ def forward( output_hidden_states=output_hidden_states, return_dict=return_dict, ) - hidden_states = model_outputs[0] + hidden_states = transformer_outputs[0] logits = self.score(hidden_states) if input_ids is not None: @@ -1420,44 +1057,48 @@ def forward( loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) if not return_dict: - output = (pooled_logits,) + model_outputs[1:] + output = (pooled_logits,) + transformer_outputs[1:] return ((loss,) + output) if loss is not None else output return SequenceClassifierOutputWithPast( loss=loss, logits=pooled_logits, - past_key_values=model_outputs.past_key_values, - hidden_states=model_outputs.hidden_states, - attentions=model_outputs.attentions, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, ) @add_start_docstrings( """ - [`Phi3Model`] with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for - Named-Entity-Recognition (NER) tasks. + The Phi3 Model transformer with a token classification head on top (a linear layer on top of the hidden-states + output) e.g. for Named-Entity-Recognition (NER) tasks. """, PHI3_START_DOCSTRING, ) -# Copied from transformers.models.mpt.modeling_mpt.MptForTokenClassification with Mpt->Phi3,MPT->PHI3,self.transformer->self.model,transformer_outputs->model_outputs class Phi3ForTokenClassification(Phi3PreTrainedModel): - def __init__(self, config: Phi3Config): + def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels - self.model = Phi3Model(config) - if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None: + if getattr(config, "classifier_dropout", None) is not None: classifier_dropout = config.classifier_dropout - elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None: + elif getattr(config, "hidden_dropout", None) is not None: classifier_dropout = config.hidden_dropout else: classifier_dropout = 0.1 self.dropout = nn.Dropout(classifier_dropout) - self.classifier = nn.Linear(config.hidden_size, config.num_labels) + self.score = nn.Linear(config.hidden_size, config.num_labels) # Initialize weights and apply final processing self.post_init() + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING) @add_code_sample_docstrings( checkpoint=_CHECKPOINT_FOR_DOC, @@ -1467,16 +1108,16 @@ def __init__(self, config: Phi3Config): def forward( self, input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, attention_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - **deprecated_arguments, - ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]: + ) -> Union[Tuple, TokenClassifierOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., @@ -1485,38 +1126,32 @@ def forward( """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict - model_outputs = self.model( + outputs = self.model( input_ids, - past_key_values=past_key_values, attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) - - hidden_states = model_outputs[0] - hidden_states = self.dropout(hidden_states) - logits = self.classifier(hidden_states) + sequence_output = outputs[0] + sequence_output = self.dropout(sequence_output) + logits = self.score(sequence_output) loss = None if labels is not None: - # move labels to correct device to enable model parallelism - labels = labels.to(logits.device) - batch_size, seq_length = labels.shape - loss_fct = CrossEntropyLoss() - loss = loss_fct( - logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length) - ) + loss = self.loss_function(logits, labels, self.config) if not return_dict: - output = (logits,) + model_outputs[2:] + output = (logits,) + outputs[2:] return ((loss,) + output) if loss is not None else output return TokenClassifierOutput( loss=loss, logits=logits, - hidden_states=model_outputs.hidden_states, - attentions=model_outputs.attentions, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, ) diff --git a/src/transformers/models/phi3/modular_phi3.py b/src/transformers/models/phi3/modular_phi3.py new file mode 100644 index 00000000000000..af44efadd341cf --- /dev/null +++ b/src/transformers/models/phi3/modular_phi3.py @@ -0,0 +1,304 @@ +# coding=utf-8 +# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""PyTorch Phi-3 model.""" + +from typing import Callable, Optional, Tuple + +import torch +import torch.utils.checkpoint +from torch import nn + +from ...activations import ACT2FN +from ...cache_utils import Cache +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS +from ...processing_utils import Unpack +from ...utils import logging +from ..mistral.modeling_mistral import ( + MistralDecoderLayer, + MistralForCausalLM, + MistralForSequenceClassification, + MistralForTokenClassification, + MistralPreTrainedModel, + MistralRotaryEmbedding, + apply_rotary_pos_emb, + eager_attention_forward, +) +from .configuration_phi3 import Phi3Config + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "microsoft/Phi-3-mini-4k-instruct" +_CONFIG_FOR_DOC = "Phi3Config" + + +class Phi3MLP(nn.Module): + def __init__(self, config): + super().__init__() + + self.config = config + self.gate_up_proj = nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False) + self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False) + self.activation_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: + up_states = self.gate_up_proj(hidden_states) + + gate, up_states = up_states.chunk(2, dim=-1) + up_states = up_states * self.activation_fn(gate) + + return self.down_proj(up_states) + + +class Phi3Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Phi3Config, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.num_key_value_heads = config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + + op_size = config.num_attention_heads * self.head_dim + 2 * (config.num_key_value_heads * self.head_dim) + self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) + self.qkv_proj = nn.Linear(config.hidden_size, op_size, bias=False) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + qkv = self.qkv_proj(hidden_states) + query_pos = self.config.num_attention_heads * self.head_dim + query_states = qkv[..., :query_pos] + key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim] + value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :] + + query_states = query_states.view(hidden_shape).transpose(1, 2) + key_states = key_states.view(hidden_shape).transpose(1, 2) + value_states = value_states.view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=getattr(self.config, "sliding_window", None), + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class Phi3DecoderLayer(MistralDecoderLayer): + def __init__(self, config: Phi3Config, layer_idx: int): + super().__init__(config, layer_idx) + self.config = config + self.self_attn = Phi3Attention(config=config, layer_idx=layer_idx) + self.mlp = Phi3MLP(config) + self.resid_attn_dropout = nn.Dropout(config.resid_pdrop) + self.resid_mlp_dropout = nn.Dropout(config.resid_pdrop) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): + input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + position_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range + `[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) + past_key_value (`Cache`, *optional*): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + self.resid_attn_dropout(hidden_states) # main diff with Llama + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + self.resid_mlp_dropout(hidden_states) # main diff with Llama + + outputs = (hidden_states,) + if output_attentions: + outputs += (self_attn_weights,) + + return outputs + + +class Phi3RotaryEmbedding(MistralRotaryEmbedding): + def __init__(self, config: Phi3Config, device=None): + super().__init__(config, device) + if self.rope_type == "longrope": + long_inv_freq, _ = self.rope_init_fn( + self.config, device, seq_len=config.original_max_position_embeddings + 1, **self.rope_kwargs + ) + self.register_buffer("long_inv_freq", long_inv_freq, persistent=False) + + @torch.no_grad() + def forward(self, x, position_ids): + inv_freq = self.inv_freq + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + inv_freq = self.inv_freq + elif self.rope_type == "longrope": + seq_len = torch.max(position_ids) + 1 + if seq_len > self.config.original_max_position_embeddings: + inv_freq = self.long_inv_freq + + # Core RoPE block + inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class Phi3PreTrainedModel(MistralPreTrainedModel): + _version = "0.0.5" + + +class Phi3ForCausalLM(MistralForCausalLM, Phi3PreTrainedModel): + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + num_logits_to_keep=None, + **kwargs, + ): + # Overwritten -- this model may need to switch between short and long rope, invalidating the cache in the + # process + + # When the first time input length reached long and short factor switching point, enforce re-compute cache + # It will cause downside of slower at this single token position, however, better than current failure. + if ( + past_key_values + and self.config.rope_scaling + and input_ids.shape[1] >= self.config.original_max_position_embeddings + 1 + ): + past_length = cache_position[0] + if past_length <= self.config.original_max_position_embeddings: + past_key_values = None + + model_inputs = Phi3PreTrainedModel().prepare_inputs_for_generation( + input_ids=input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + position_ids=position_ids, + use_cache=use_cache, + num_logits_to_keep=num_logits_to_keep, + **kwargs, + ) + return model_inputs + + +class Phi3ForSequenceClassification(MistralForSequenceClassification): + pass + + +class Phi3ForTokenClassification(MistralForTokenClassification): + pass diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 36fb1ddf1390ac..c0e0d88431c268 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -283,11 +283,7 @@ def forward( class Qwen2RotaryEmbedding(nn.Module): - def __init__( - self, - config: Qwen2Config, - device=None, - ): + def __init__(self, config: Qwen2Config, device=None): super().__init__() self.rope_kwargs = {} # BC: "rope_type" was originally "type" diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index 3b4fdbcb81ccc4..a510ca1e1ca26a 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -274,11 +274,7 @@ def forward( class Starcoder2RotaryEmbedding(nn.Module): - def __init__( - self, - config: Starcoder2Config, - device=None, - ): + def __init__(self, config: Starcoder2Config, device=None): super().__init__() self.rope_kwargs = {} # BC: "rope_type" was originally "type" diff --git a/tests/models/phi3/test_modeling_phi3.py b/tests/models/phi3/test_modeling_phi3.py index 2c5557dfd67aae..6ec663c6636fa6 100644 --- a/tests/models/phi3/test_modeling_phi3.py +++ b/tests/models/phi3/test_modeling_phi3.py @@ -459,6 +459,9 @@ def test_model_rope_scaling_short_long_factor(self, scaling_type): "long_factor": [5.0 for _ in range(n_factors)], } input_tensor = ids_tensor([1, 4090], config.vocab_size) + # Make sure we don't have padding tokens. If this is the case, then the actual number of "true" tokens may be shorter + # than `config.original_max_position_embeddings + 5`, invalidating this test + input_tensor[input_tensor == config.pad_token_id] += 1 model = Phi3ForCausalLM(config) model.to(torch_device) model.eval()