diff --git a/examples/modular-transformers/configuration_my_new_model.py b/examples/modular-transformers/configuration_my_new_model.py index 3c7848e6956..aa0aac55ba9 100644 --- a/examples/modular-transformers/configuration_my_new_model.py +++ b/examples/modular-transformers/configuration_my_new_model.py @@ -1,9 +1,9 @@ -# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -# This file was automatically generated from . -# Do NOT edit this file manually as any edits will be overwritten by the generation of -# the file from the modular. If any change should be done, please apply the change to the -# modular_xxx.py file directly. One of our CI enforces this -# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from examples/modular-transformers/modular_my_new_model.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_my_new_model.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 from ...configuration_utils import PretrainedConfig from ...modeling_rope_utils import rope_config_validation @@ -158,6 +158,13 @@ def __init__( new_param=0, **kwargs, ): + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) self.vocab_size = vocab_size self.max_position_embeddings = max_position_embeddings self.hidden_size = hidden_size @@ -187,11 +194,3 @@ def __init__( self.rope_scaling["rope_type"] = self.rope_scaling["type"] rope_config_validation(self) self.new_param = new_param - - super().__init__( - pad_token_id=pad_token_id, - bos_token_id=bos_token_id, - eos_token_id=eos_token_id, - tie_word_embeddings=tie_word_embeddings, - **kwargs, - ) diff --git a/examples/modular-transformers/configuration_my_new_model2.py b/examples/modular-transformers/configuration_my_new_model2.py index 5fef1cecc70..f05ace94b62 100644 --- a/examples/modular-transformers/configuration_my_new_model2.py +++ b/examples/modular-transformers/configuration_my_new_model2.py @@ -1,9 +1,9 @@ -# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -# This file was automatically generated from . -# Do NOT edit this file manually as any edits will be overwritten by the generation of -# the file from the modular. If any change should be done, please apply the change to the -# modular_xxx.py file directly. One of our CI enforces this -# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from examples/modular-transformers/modular_my_new_model2.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_my_new_model2.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 from ...configuration_utils import PretrainedConfig from ...modeling_rope_utils import rope_config_validation @@ -11,106 +11,6 @@ class MyNewModel2Config(PretrainedConfig): r""" - This is the configuration class to store the configuration of a [`MyNewModel2Model`]. It is used to instantiate an MyNewModel2 - model according to the specified arguments, defining the model architecture. 
Instantiating a configuration with the - defaults will yield a similar configuration to that of the MyNewModel2-7B. - - Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the - documentation from [`PretrainedConfig`] for more information. - - - Args: - vocab_size (`int`, *optional*, defaults to 32000): - Vocabulary size of the MyNewModel2 model. Defines the number of different tokens that can be represented by the - `inputs_ids` passed when calling [`MyNewModel2Model`] - hidden_size (`int`, *optional*, defaults to 4096): - Dimension of the hidden representations. - intermediate_size (`int`, *optional*, defaults to 11008): - Dimension of the MLP representations. - num_hidden_layers (`int`, *optional*, defaults to 32): - Number of hidden layers in the Transformer decoder. - num_attention_heads (`int`, *optional*, defaults to 32): - Number of attention heads for each attention layer in the Transformer decoder. - num_key_value_heads (`int`, *optional*): - This is the number of key_value heads that should be used to implement Grouped Query Attention. If - `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if - `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When - converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed - by meanpooling all the original heads within that group. For more details checkout [this - paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to - `num_attention_heads`. - hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): - The non-linear activation function (function or string) in the decoder. - max_position_embeddings (`int`, *optional*, defaults to 2048): - The maximum sequence length that this model might ever be used with. MyNewModel2 1 supports up to 2048 tokens, - MyNewModel2 2 up to 4096, CodeMyNewModel2 up to 16384. - initializer_range (`float`, *optional*, defaults to 0.02): - The standard deviation of the truncated_normal_initializer for initializing all weight matrices. - rms_norm_eps (`float`, *optional*, defaults to 1e-06): - The epsilon used by the rms normalization layers. - use_cache (`bool`, *optional*, defaults to `True`): - Whether or not the model should return the last key/values attentions (not used by all models). Only - relevant if `config.is_decoder=True`. - pad_token_id (`int`, *optional*): - Padding token id. - bos_token_id (`int`, *optional*, defaults to 1): - Beginning of stream token id. - eos_token_id (`int`, *optional*, defaults to 2): - End of stream token id. - pretraining_tp (`int`, *optional*, defaults to 1): - Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this - document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to - understand more about it. This value is necessary to ensure exact reproducibility of the pretraining - results. Please refer to [this issue](https://github.com/pytorch/pytorch/issues/76232). - tie_word_embeddings (`bool`, *optional*, defaults to `False`): - Whether to tie weight embeddings - rope_theta (`float`, *optional*, defaults to 10000.0): - The base period of the RoPE embeddings. - rope_scaling (`Dict`, *optional*): - Dictionary containing the scaling configuration for the RoPE embeddings. 
NOTE: if you apply new rope type - and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value - accordingly. - Expected contents: - `rope_type` (`str`): - The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', - 'my_new_model23'], with 'default' being the original RoPE implementation. - `factor` (`float`, *optional*): - Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In - most scaling types, a `factor` of x will enable the model to handle sequences of length x * - original maximum pre-trained length. - `original_max_position_embeddings` (`int`, *optional*): - Used with 'dynamic', 'longrope' and 'my_new_model23'. The original max position embeddings used during - pretraining. - `attention_factor` (`float`, *optional*): - Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention - computation. If unspecified, it defaults to value recommended by the implementation, using the - `factor` field to infer the suggested value. - `beta_fast` (`float`, *optional*): - Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear - ramp function. If unspecified, it defaults to 32. - `beta_slow` (`float`, *optional*): - Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear - ramp function. If unspecified, it defaults to 1. - `short_factor` (`List[float]`, *optional*): - Only used with 'longrope'. The scaling factor to be applied to short contexts (< - `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden - size divided by the number of attention heads divided by 2 - `long_factor` (`List[float]`, *optional*): - Only used with 'longrope'. The scaling factor to be applied to long contexts (< - `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden - size divided by the number of attention heads divided by 2 - `low_freq_factor` (`float`, *optional*): - Only used with 'my_new_model23'. Scaling factor applied to low frequency components of the RoPE - `high_freq_factor` (`float`, *optional*): - Only used with 'my_new_model23'. Scaling factor applied to high frequency components of the RoPE - attention_bias (`bool`, *optional*, defaults to `False`): - Whether to use a bias in the query, key, value and output projection layers during self-attention. - attention_dropout (`float`, *optional*, defaults to 0.0): - The dropout ratio for the attention probabilities. - mlp_bias (`bool`, *optional*, defaults to `False`): - Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers. - head_dim (`int`, *optional*): - The attention head dimension. If None, it will default to hidden_size // num_heads This is the configuration class to store the configuration of a [`GemmaModel`]. It is used to instantiate an Gemma model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the Gemma-7B. @@ -121,7 +21,6 @@ class MyNewModel2Config(PretrainedConfig): vocab_size (`int`, *optional*, defaults to 256000): Vocabulary size of the Gemma model. 
Defines the number of different tokens that can be represented by the `inputs_ids` passed when calling [`GemmaModel`] - ```python >>> from transformers import GemmaModel, GemmaConfig >>> # Initializing a Gemma gemma-7b style configuration diff --git a/examples/modular-transformers/configuration_new_model.py b/examples/modular-transformers/configuration_new_model.py index 8bc8ef52cee..4d164fe3e75 100644 --- a/examples/modular-transformers/configuration_new_model.py +++ b/examples/modular-transformers/configuration_new_model.py @@ -1,9 +1,9 @@ -# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -# This file was automatically generated from . -# Do NOT edit this file manually as any edits will be overwritten by the generation of -# the file from the modular. If any change should be done, please apply the change to the -# modular_xxx.py file directly. One of our CI enforces this -# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from examples/modular-transformers/modular_new_model.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_new_model.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # Example where we only want to overwrite the defaults of an init from ...configuration_utils import PretrainedConfig @@ -104,6 +104,13 @@ def __init__( attention_dropout=0.0, **kwargs, ): + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) self.vocab_size = vocab_size self.max_position_embeddings = max_position_embeddings self.hidden_size = hidden_size @@ -121,14 +128,6 @@ def __init__( self.attention_bias = attention_bias self.attention_dropout = attention_dropout - super().__init__( - pad_token_id=pad_token_id, - bos_token_id=bos_token_id, - eos_token_id=eos_token_id, - tie_word_embeddings=tie_word_embeddings, - **kwargs, - ) - @property def num_heads(self): return self.num_attention_heads diff --git a/examples/modular-transformers/modeling_dummy.py b/examples/modular-transformers/modeling_dummy.py index b5b1fc6aec8..ed7e3c64d7a 100644 --- a/examples/modular-transformers/modeling_dummy.py +++ b/examples/modular-transformers/modeling_dummy.py @@ -1,26 +1,24 @@ -# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -# This file was automatically generated from . -# Do NOT edit this file manually as any edits will be overwritten by the generation of -# the file from the modular. If any change should be done, please apply the change to the -# modular_xxx.py file directly. One of our CI enforces this -# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from examples/modular-transformers/modular_dummy.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_dummy.py file directly. One of our CI enforces this. 
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 import math from typing import List, Optional, Tuple, Union import torch import torch.nn.functional as F -import torch.utils.checkpoint from torch import nn from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache from ...modeling_attn_mask_utils import AttentionMaskConverter -from ...modeling_flash_attention_utils import _flash_attention_forward -from ...modeling_outputs import ( - BaseModelOutputWithPast, -) +from ...modeling_flash_attention_utils import FlashAttentionKwargs, _flash_attention_forward +from ...modeling_outputs import BaseModelOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import PreTrainedModel +from ...processing_utils import Unpack from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, @@ -33,59 +31,6 @@ logger = logging.get_logger(__name__) -def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - device: torch.device, - min_dtype: float, - cache_position: torch.Tensor, - batch_size: int, -): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to plcae the 4D attention mask on. - min_dtype (`float`): - The minimum value representable with the dtype `dtype`. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
- causal_mask = attention_mask - else: - causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask - - class DummyRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ @@ -193,40 +138,6 @@ def forward(self, x, position_ids): return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 4] - x2 = x[..., x.shape[-1] // 4 :] - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): - """Applies Rotary Position Embedding to the query and key tensors. - - Args: - q (`torch.Tensor`): The query tensor. - k (`torch.Tensor`): The key tensor. - cos (`torch.Tensor`): The cosine part of the rotary embedding. - sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. - Returns: - `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. - """ - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - class DummyMLP(nn.Module): def __init__(self, config): super().__init__() @@ -261,6 +172,40 @@ def forward(self, x): return down_proj +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 4] + x2 = x[..., x.shape[-1] // 4 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. 
+ unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, @@ -423,6 +368,7 @@ def forward( use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if isinstance(past_key_value, StaticCache): raise ValueError( @@ -507,6 +453,7 @@ def forward( sliding_window=getattr(self, "sliding_window", None), use_top_left_mask=self._flash_attn_uses_top_left_mask, is_causal=self.is_causal, + **kwargs, ) attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() @@ -871,6 +818,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -952,6 +900,7 @@ def forward( use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, + **flash_attn_kwargs, ) hidden_states = layer_outputs[0] @@ -1011,10 +960,9 @@ def _update_causal_mask( return None dtype, device = input_tensor.dtype, input_tensor.device - min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] if using_static_cache: - target_length = past_key_values.get_max_length() + target_length = past_key_values.get_max_cache_shape() else: target_length = ( attention_mask.shape[-1] @@ -1023,13 +971,12 @@ def _update_causal_mask( ) # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). - causal_mask = _prepare_4d_causal_attention_mask_with_cache_position( + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( attention_mask, sequence_length=sequence_length, target_length=target_length, dtype=dtype, device=device, - min_dtype=min_dtype, cache_position=cache_position, batch_size=input_tensor.shape[0], ) @@ -1043,6 +990,63 @@ def _update_causal_mask( # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. 
# Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask diff --git a/examples/modular-transformers/modeling_dummy_bert.py b/examples/modular-transformers/modeling_dummy_bert.py index 611d7be961f..e18e6a19e8a 100644 --- a/examples/modular-transformers/modeling_dummy_bert.py +++ b/examples/modular-transformers/modeling_dummy_bert.py @@ -1,27 +1,20 @@ -# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -# This file was automatically generated from . -# Do NOT edit this file manually as any edits will be overwritten by the generation of -# the file from the modular. If any change should be done, please apply the change to the -# modular_xxx.py file directly. One of our CI enforces this -# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from examples/modular-transformers/modular_dummy_bert.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_dummy_bert.py file directly. 
One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 import math import os from typing import List, Optional, Tuple, Union import torch -import torch.utils.checkpoint from packaging import version from torch import nn from ...activations import ACT2FN -from ...modeling_attn_mask_utils import ( - _prepare_4d_attention_mask_for_sdpa, - _prepare_4d_causal_attention_mask_for_sdpa, -) -from ...modeling_outputs import ( - BaseModelOutputWithPastAndCrossAttentions, - BaseModelOutputWithPoolingAndCrossAttentions, -) +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa +from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ( @@ -40,79 +33,6 @@ _CONFIG_FOR_DOC = "DummyBertConfig" -def load_tf_weights_in_dummy_bert(model, config, tf_checkpoint_path): - """Load tf checkpoints in a pytorch model.""" - try: - import re - - import numpy as np - import tensorflow as tf - except ImportError: - logger.error( - "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " - "https://www.tensorflow.org/install/ for installation instructions." - ) - raise - tf_path = os.path.abspath(tf_checkpoint_path) - logger.info(f"Converting TensorFlow checkpoint from {tf_path}") - # Load weights from TF model - init_vars = tf.train.list_variables(tf_path) - names = [] - arrays = [] - for name, shape in init_vars: - logger.info(f"Loading TF weight {name} with shape {shape}") - array = tf.train.load_variable(tf_path, name) - names.append(name) - arrays.append(array) - - for name, array in zip(names, arrays): - name = name.split("/") - # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v - # which are not required for using pretrained model - if any( - n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] - for n in name - ): - logger.info(f"Skipping {'/'.join(name)}") - continue - pointer = model - for m_name in name: - if re.fullmatch(r"[A-Za-z]+_\d+", m_name): - scope_names = re.split(r"_(\d+)", m_name) - else: - scope_names = [m_name] - if scope_names[0] == "kernel" or scope_names[0] == "gamma": - pointer = getattr(pointer, "weight") - elif scope_names[0] == "output_bias" or scope_names[0] == "beta": - pointer = getattr(pointer, "bias") - elif scope_names[0] == "output_weights": - pointer = getattr(pointer, "weight") - elif scope_names[0] == "squad": - pointer = getattr(pointer, "classifier") - else: - try: - pointer = getattr(pointer, scope_names[0]) - except AttributeError: - logger.info(f"Skipping {'/'.join(name)}") - continue - if len(scope_names) >= 2: - num = int(scope_names[1]) - pointer = pointer[num] - if m_name[-11:] == "_embeddings": - pointer = getattr(pointer, "weight") - elif m_name == "kernel": - array = np.transpose(array) - try: - if pointer.shape != array.shape: - raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched") - except ValueError as e: - e.args += (pointer.shape, array.shape) - raise - logger.info(f"Initialize PyTorch weight {name}") - pointer.data = torch.from_numpy(array) - return model - - class DummyBertEmbeddings(nn.Module): """Construct the embeddings from word, position and token_type embeddings.""" @@ -706,6 
+626,79 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return pooled_output +def load_tf_weights_in_dummy_bert(model, config, tf_checkpoint_path): + """Load tf checkpoints in a pytorch model.""" + try: + import re + + import numpy as np + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." + ) + raise + tf_path = os.path.abspath(tf_checkpoint_path) + logger.info(f"Converting TensorFlow checkpoint from {tf_path}") + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + arrays = [] + for name, shape in init_vars: + logger.info(f"Loading TF weight {name} with shape {shape}") + array = tf.train.load_variable(tf_path, name) + names.append(name) + arrays.append(array) + + for name, array in zip(names, arrays): + name = name.split("/") + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + if any( + n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] + for n in name + ): + logger.info(f"Skipping {'/'.join(name)}") + continue + pointer = model + for m_name in name: + if re.fullmatch(r"[A-Za-z]+_\d+", m_name): + scope_names = re.split(r"_(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] == "kernel" or scope_names[0] == "gamma": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "output_bias" or scope_names[0] == "beta": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "output_weights": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "squad": + pointer = getattr(pointer, "classifier") + else: + try: + pointer = getattr(pointer, scope_names[0]) + except AttributeError: + logger.info(f"Skipping {'/'.join(name)}") + continue + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + if m_name[-11:] == "_embeddings": + pointer = getattr(pointer, "weight") + elif m_name == "kernel": + array = np.transpose(array) + try: + if pointer.shape != array.shape: + raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched") + except ValueError as e: + e.args += (pointer.shape, array.shape) + raise + logger.info(f"Initialize PyTorch weight {name}") + pointer.data = torch.from_numpy(array) + return model + + class DummyBertPreTrainedModel(PreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained @@ -871,26 +864,6 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]: - r""" - encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if - the model is configured as a decoder. - encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, target_length)`, *optional*): - Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in - the cross-attention if the model is configured as a decoder. 
Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): - Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. - - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). - """ r""" encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if @@ -1027,7 +1000,6 @@ def forward( if not return_dict: return (sequence_output, pooled_output) + encoder_outputs[1:] - return super().forward(input_ids) return BaseModelOutputWithPoolingAndCrossAttentions( last_hidden_state=sequence_output, diff --git a/examples/modular-transformers/modeling_my_new_model2.py b/examples/modular-transformers/modeling_my_new_model2.py index 49cdd274162..16f9e525a05 100644 --- a/examples/modular-transformers/modeling_my_new_model2.py +++ b/examples/modular-transformers/modeling_my_new_model2.py @@ -1,25 +1,20 @@ -# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -# This file was automatically generated from . -# Do NOT edit this file manually as any edits will be overwritten by the generation of -# the file from the modular. If any change should be done, please apply the change to the -# modular_xxx.py file directly. One of our CI enforces this -# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from examples/modular-transformers/modular_my_new_model2.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_my_new_model2.py file directly. One of our CI enforces this. 
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 import math from typing import List, Optional, Tuple, Union import torch -import torch.utils.checkpoint from torch import nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import _flash_attention_forward -from ...modeling_outputs import ( - BaseModelOutputWithPast, - SequenceClassifierOutputWithPast, -) +from ...modeling_outputs import BaseModelOutputWithPast, SequenceClassifierOutputWithPast from ...modeling_utils import PreTrainedModel from ...utils import ( add_start_docstrings, @@ -30,6 +25,9 @@ from .configuration_my_new_model2 import MyNewModel2Config +logger = logging.get_logger(__name__) + + class MyNewModel2RMSNorm(nn.Module): def __init__(self, dim: int, eps: float = 1e-6): super().__init__() @@ -50,9 +48,6 @@ def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.eps}" -logger = logging.get_logger(__name__) - - class MyNewModel2RotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() @@ -448,59 +443,6 @@ def forward( return attn_output, attn_weights, past_key_value -def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - device: torch.device, - min_dtype: float, - cache_position: torch.Tensor, - batch_size: int, -): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to plcae the 4D attention mask on. - min_dtype (`float`): - The minimum value representable with the dtype `dtype`. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
- causal_mask = attention_mask - else: - causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask - - MY_NEW_MODEL2_ATTENTION_CLASSES = { "eager": MyNewModel2Attention, "flash_attention_2": MyNewModel2FlashAttention2, @@ -893,10 +835,9 @@ def _update_causal_mask( return None dtype, device = input_tensor.dtype, input_tensor.device - min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] if using_static_cache: - target_length = past_key_values.get_max_length() + target_length = past_key_values.get_max_cache_shape() else: target_length = ( attention_mask.shape[-1] @@ -905,13 +846,12 @@ def _update_causal_mask( ) # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). - causal_mask = _prepare_4d_causal_attention_mask_with_cache_position( + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( attention_mask, sequence_length=sequence_length, target_length=target_length, dtype=dtype, device=device, - min_dtype=min_dtype, cache_position=cache_position, batch_size=input_tensor.shape[0], ) @@ -925,10 +865,67 @@ def _update_causal_mask( # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) return causal_mask + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. 
+ """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + @add_start_docstrings( """ @@ -1019,27 +1016,8 @@ def forward( loss = None if labels is not None: - labels = labels.to(logits.device) - if self.config.problem_type is None: - if self.num_labels == 1: - self.config.problem_type = "regression" - elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): - self.config.problem_type = "single_label_classification" - else: - self.config.problem_type = "multi_label_classification" - - if self.config.problem_type == "regression": - loss_fct = MSELoss() - if self.num_labels == 1: - loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) - else: - loss = loss_fct(pooled_logits, labels) - elif self.config.problem_type == "single_label_classification": - loss_fct = CrossEntropyLoss() - loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) - elif self.config.problem_type == "multi_label_classification": - loss_fct = BCEWithLogitsLoss() - loss = loss_fct(pooled_logits, labels) + loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) + if not return_dict: output = (pooled_logits,) + transformer_outputs[1:] return ((loss,) + output) if loss is not None else output diff --git a/examples/modular-transformers/modeling_new_task_model.py b/examples/modular-transformers/modeling_new_task_model.py index 640331ace1d..4556308f1ea 100644 --- a/examples/modular-transformers/modeling_new_task_model.py +++ b/examples/modular-transformers/modeling_new_task_model.py @@ -8,7 +8,6 @@ from typing import ClassVar, List, Optional, Tuple, Union import torch -import torch.utils.checkpoint from torch import nn from ...cache_utils import Cache, StaticCache @@ -18,92 +17,15 @@ ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_2_available, - logging, replace_return_docstrings, ) -from .configuration_new_task_model import NewTaskModelConfig - - -if is_flash_attn_2_available(): - from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa - from ..auto import AutoModel, AutoModelForCausalLM +from .configuration_new_task_model import NewTaskModelConfig -logger = logging.get_logger(__name__) - _CONFIG_FOR_DOC = "NewTaskModelConfig" -# Adapted from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position -# But NewTaskModel has no causal mask on prefix -def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - 
sequence_length: int, - target_length: int, - dtype: torch.dtype, - device: torch.device, - min_dtype: float, - cache_position: torch.Tensor, - batch_size: int, - is_training: bool = False, - token_type_ids: torch.Tensor = None, -): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to plcae the 4D attention mask on. - min_dtype (`float`): - The minimum value representable with the dtype `dtype`. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - is_training (`bool`): - Whether the model is in training mode or in inference. The condition is checked by presence/absence of `token_type_ids/labels` - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - causal_mask = attention_mask - else: - causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) - # Causal diagonal mask only if training, otherwise attend to the whole prefix. 
Training-specific attn for prefix is handled below - if sequence_length != 1: - if is_training: - causal_mask = torch.triu(causal_mask, diagonal=1) - else: - causal_mask[:, :sequence_length] = 0.0 - - causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - # we are training thus we need to create a full mask on the image + prefix but causal on suffix - if is_training: - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0 - ) - return causal_mask - - @dataclass class NewTaskModelCausalLMOutputWithPast(ModelOutput): """ @@ -182,12 +104,12 @@ class NewTaskModelPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["NewTaskModelMultiModalProjector"] _skip_keys_device_placement = "past_key_values" - _supports_flash_attn_2 = False _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True - _supports_sdpa = True _supports_cache_class = True + _supports_flash_attn_2 = True + _supports_sdpa = True def _init_weights(self, module): # important: this ported version of NewTaskModelisn't meant for training from scratch - only @@ -210,14 +132,6 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - @property - def _supports_sdpa(self): - """ - Retrieve language_model's attribute to check whether the model supports - SDPA or not. - """ - return self.language_model._supports_sdpa - NEW_TASK_MODEL_INPUTS_DOCSTRING = r""" Args: @@ -301,11 +215,8 @@ def __init__(self, config): self.vision_tower = AutoModel.from_config(config=config.vision_config) self.multi_modal_projector = NewTaskModelMultiModalProjector(config) self.vocab_size = config.text_config.vocab_size - self._attn_implementation = config._attn_implementation - language_model = AutoModelForCausalLM.from_config( - config=config.text_config, attn_implementation=self._attn_implementation - ) + language_model = AutoModelForCausalLM.from_config(config=config.text_config) if language_model._tied_weights_keys is not None: self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys] @@ -344,6 +255,11 @@ def tie_weights(self): def _update_causal_mask( self, attention_mask, token_type_ids, inputs_embeds, past_key_values, cache_position, is_training: bool = False ): + if self.config.text_config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + using_static_cache = isinstance(past_key_values, StaticCache) dtype = inputs_embeds.dtype min_dtype = torch.finfo(dtype).min @@ -388,6 +304,22 @@ def _update_causal_mask( ) return causal_mask + def get_image_features(self, pixel_values: torch.FloatTensor): + """ + Obtains image last hidden states from the vision tower and apply multimodal projection. 
+ + Args: + pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`) + The tensors corresponding to the input images. + Returns: + image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`). + """ + image_outputs = self.vision_tower(pixel_values) + selected_image_feature = image_outputs.last_hidden_state + image_features = self.multi_modal_projector(selected_image_feature) + image_features = image_features / (self.config.hidden_size**0.5) + return image_features + @add_start_docstrings_to_model_forward(NEW_TASK_MODEL_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=NewTaskModelCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -426,9 +358,9 @@ def forward( ```python >>> from PIL import Image >>> import requests - >>> from transformers import AutoProcessor, NewTaskModelForNewTask + >>> from transformers import AutoProcessor, NewTaskModelForConditionalGeneration - >>> model = NewTaskModelForNewTask.from_pretrained("google/NewTaskModel-test-224px-hf") + >>> model = NewTaskModelForConditionalGeneration.from_pretrained("google/NewTaskModel-test-224px-hf") >>> processor = AutoProcessor.from_pretrained("google/NewTaskModel-test-224px-hf") >>> prompt = "answer en Where is the cow standing?" @@ -484,6 +416,7 @@ def prepare_inputs_for_generation( num_logits_to_keep=None, **kwargs, ): + # Overwritten -- custom `position_ids` and `pixel_values` handling model_inputs = self.language_model.prepare_inputs_for_generation( input_ids, past_key_values=past_key_values, @@ -493,33 +426,10 @@ def prepare_inputs_for_generation( cache_position=cache_position, use_cache=use_cache, num_logits_to_keep=num_logits_to_keep, + token_type_ids=token_type_ids, **kwargs, ) - if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: - if model_inputs["inputs_embeds"] is not None: - batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape - device = model_inputs["inputs_embeds"].device - else: - batch_size, sequence_length = model_inputs["input_ids"].shape - device = model_inputs["input_ids"].device - - dtype = self.get_output_embeddings().weight.dtype - min_dtype = torch.finfo(dtype).min - - model_inputs["attention_mask"] = _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=past_key_values.get_max_length(), - dtype=dtype, - device=device, - min_dtype=min_dtype, - cache_position=cache_position, - batch_size=batch_size, - ) - - model_inputs["token_type_ids"] = token_type_ids - # position_ids in NewTaskModel are 1-indexed if model_inputs.get("position_ids") is not None: model_inputs["position_ids"] += 1 diff --git a/examples/modular-transformers/modeling_roberta.py b/examples/modular-transformers/modeling_roberta.py new file mode 100644 index 00000000000..e50cf60c3a4 --- /dev/null +++ b/examples/modular-transformers/modeling_roberta.py @@ -0,0 +1,1014 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from examples/modular-transformers/modular_roberta.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_roberta.py file directly. One of our CI enforces this. 
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +import math +import os +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from packaging import version + +from ...activations import ACT2FN +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa +from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + get_torch_version, + logging, +) +from .configuration_roberta import RobertaConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "google-roberta/roberta-base-uncased" +_CONFIG_FOR_DOC = "RobertaConfig" + + +class RobertaEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding( + config.max_position_embeddings, config.hidden_size, config.pad_token_id + ) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False + ) + self.register_buffer( + "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False + ) + self.pad_token_id = config.pad_token_id + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values_length: int = 0, + ) -> torch.Tensor: + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if position_ids is None: + position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length] + + # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs + # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves + # issue #5664 + if token_type_ids is None: + if hasattr(self, "token_type_ids"): + buffered_token_type_ids = self.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + token_type_embeddings = 
self.token_type_embeddings(token_type_ids) + + embeddings = inputs_embeds + token_type_embeddings + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class RobertaSelfAttention(nn.Module): + def __init__(self, config, position_embedding_type=None): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({config.num_attention_heads})" + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = position_embedding_type or getattr( + config, "position_embedding_type", "absolute" + ) + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) + + self.is_decoder = config.is_decoder + + def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + mixed_query_layer = self.query(hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. 
+ is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_layer = past_key_value[0] + value_layer = past_key_value[1] + attention_mask = encoder_attention_mask + elif is_cross_attention: + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + use_cache = past_key_value is not None + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_layer, value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + query_length, key_length = query_layer.shape[2], key_layer.shape[2] + if use_cache: + position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( + -1, 1 + ) + else: + position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) + position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1) + distance = position_ids_l - position_ids_r + + positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) + positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in RobertaModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. 
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + if self.is_decoder: + outputs = outputs + (past_key_value,) + return outputs + + +class RobertaSdpaSelfAttention(RobertaSelfAttention): + def __init__(self, config, position_embedding_type=None): + super().__init__(config, position_embedding_type=position_embedding_type) + self.dropout_prob = config.attention_probs_dropout_prob + self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0") + + # Adapted from RobertaSelfAttention + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + if self.position_embedding_type != "absolute" or output_attentions or head_mask is not None: + # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once implemented. + logger.warning_once( + "RobertaSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support " + "non-absolute `position_embedding_type` or `output_attentions=True` or `head_mask`. Falling back to " + "the manual attention implementation, but specifying the manual implementation will be required from " + "Transformers version v5.0.0 onwards. This warning can be removed using the argument " + '`attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + + bsz, tgt_len, _ = hidden_states.size() + + query_layer = self.transpose_for_scores(self.query(hidden_states)) + + # If this is instantiated as a cross-attention module, the keys and values come from an encoder; the attention + # mask needs to be such that the encoder's padding tokens are not attended to. 
+ is_cross_attention = encoder_hidden_states is not None + + current_states = encoder_hidden_states if is_cross_attention else hidden_states + attention_mask = encoder_attention_mask if is_cross_attention else attention_mask + + # Check `seq_length` of `past_key_value` == `len(current_states)` to support prefix tuning + if is_cross_attention and past_key_value and past_key_value[0].shape[2] == current_states.shape[1]: + key_layer, value_layer = past_key_value + else: + key_layer = self.transpose_for_scores(self.key(current_states)) + value_layer = self.transpose_for_scores(self.value(current_states)) + if past_key_value is not None and not is_cross_attention: + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_layer, value_layer) + + # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom + # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0. + # Reference: https://github.com/pytorch/pytorch/issues/112577 + if self.require_contiguous_qkv and query_layer.device.type == "cuda" and attention_mask is not None: + query_layer = query_layer.contiguous() + key_layer = key_layer.contiguous() + value_layer = value_layer.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create + # a causal mask in case tgt_len == 1. 
+ is_causal = ( + True if self.is_decoder and not is_cross_attention and attention_mask is None and tgt_len > 1 else False + ) + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_layer, + key_layer, + value_layer, + attn_mask=attention_mask, + dropout_p=self.dropout_prob if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, tgt_len, self.all_head_size) + + outputs = (attn_output,) + if self.is_decoder: + outputs = outputs + (past_key_value,) + return outputs + + +class RobertaSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +ROBERTA_SELF_ATTENTION_CLASSES = { + "eager": RobertaSelfAttention, + "sdpa": RobertaSdpaSelfAttention, +} + + +class RobertaAttention(nn.Module): + def __init__(self, config, position_embedding_type=None): + super().__init__() + self.self = ROBERTA_SELF_ATTENTION_CLASSES[config._attn_implementation]( + config, position_embedding_type=position_embedding_type + ) + self.output = RobertaSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +class RobertaIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = 
self.intermediate_act_fn(hidden_states) + return hidden_states + + +class RobertaOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class RobertaLayer(nn.Module): + def __init__(self, config): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = RobertaAttention(config) + self.is_decoder = config.is_decoder + self.add_cross_attention = config.add_cross_attention + if self.add_cross_attention: + if not self.is_decoder: + raise ValueError(f"{self} should be used as a decoder model if cross attention is added") + self.crossattention = RobertaAttention(config, position_embedding_type="absolute") + self.intermediate = RobertaIntermediate(config) + self.output = RobertaOutput(config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + attention_output = self_attention_outputs[0] + + # if decoder, the last output is tuple of self-attn cache + if self.is_decoder: + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + else: + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + cross_attn_present_key_value = None + if self.is_decoder and encoder_hidden_states is not None: + if not hasattr(self, "crossattention"): + raise ValueError( + f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers" + " by setting `config.add_cross_attention=True`" + ) + + # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + cross_attn_past_key_value, + output_attentions, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + + # add cross-attn cache to positions 3,4 of present_key_value tuple + cross_attn_present_key_value = cross_attention_outputs[-1] + present_key_value = present_key_value + cross_attn_present_key_value + + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output + 
) + outputs = (layer_output,) + outputs + + # if decoder, return the attn key/values as the last output + if self.is_decoder: + outputs = outputs + (present_key_value,) + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class RobertaEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList([RobertaLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + return_dict: Optional[bool] = True, + ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + next_decoder_cache = () if use_cache else None + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[i] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer_module.__call__, + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +class RobertaPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # We "pool" the model by simply 
taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +def load_tf_weights_in_roberta(model, config, tf_checkpoint_path): + """Load tf checkpoints in a pytorch model.""" + try: + import re + + import numpy as np + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." + ) + raise + tf_path = os.path.abspath(tf_checkpoint_path) + logger.info(f"Converting TensorFlow checkpoint from {tf_path}") + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + arrays = [] + for name, shape in init_vars: + logger.info(f"Loading TF weight {name} with shape {shape}") + array = tf.train.load_variable(tf_path, name) + names.append(name) + arrays.append(array) + + for name, array in zip(names, arrays): + name = name.split("/") + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + if any( + n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] + for n in name + ): + logger.info(f"Skipping {'/'.join(name)}") + continue + pointer = model + for m_name in name: + if re.fullmatch(r"[A-Za-z]+_\d+", m_name): + scope_names = re.split(r"_(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] == "kernel" or scope_names[0] == "gamma": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "output_bias" or scope_names[0] == "beta": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "output_weights": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "squad": + pointer = getattr(pointer, "classifier") + else: + try: + pointer = getattr(pointer, scope_names[0]) + except AttributeError: + logger.info(f"Skipping {'/'.join(name)}") + continue + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + if m_name[-11:] == "_embeddings": + pointer = getattr(pointer, "weight") + elif m_name == "kernel": + array = np.transpose(array) + try: + if pointer.shape != array.shape: + raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched") + except ValueError as e: + e.args += (pointer.shape, array.shape) + raise + logger.info(f"Initialize PyTorch weight {name}") + pointer.data = torch.from_numpy(array) + return model + + +class RobertaPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = RobertaConfig + load_tf_weights = load_tf_weights_in_roberta + base_model_prefix = "roberta" + supports_gradient_checkpointing = True + _supports_sdpa = True + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +ROBERTA_START_DOCSTRING = r""" + + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`RobertaConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +ROBERTA_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `({0})`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.FloatTensor` of shape `({0})`or `(batch_size, sequence_length, target_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. 
This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare Roberta Model transformer outputting raw hidden-states without any specific head on top.", + ROBERTA_START_DOCSTRING, +) +class RobertaModel(RobertaPreTrainedModel): + """ + + The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of + cross-attention is added between the self-attention layers, following the architecture described in [Attention is + all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, + Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. + + To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set + to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and + `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass. + """ + + _no_split_modules = ["RobertaEmbeddings", "RobertaLayer"] + + def __init__(self, config, add_pooling_layer=True): + super().__init__(config) + self.config = config + + self.embeddings = RobertaEmbeddings(config) + self.encoder = RobertaEncoder(config) + + self.pooler = RobertaPooler(config) if add_pooling_layer else None + + self.attn_implementation = config._attn_implementation + self.position_embedding_type = config.position_embedding_type + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPoolingAndCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]: + r""" + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, target_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). 
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.config.is_decoder: + use_cache = use_cache if use_cache is not None else self.config.use_cache + else: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + batch_size, seq_length = input_shape + device = input_ids.device if input_ids is not None else inputs_embeds.device + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if token_type_ids is None: + if hasattr(self.embeddings, "token_type_ids"): + buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device) + + use_sdpa_attention_masks = ( + self.attn_implementation == "sdpa" + and self.position_embedding_type == "absolute" + and head_mask is None + and not output_attentions + ) + + # Expand the attention mask + if use_sdpa_attention_masks and attention_mask.dim() == 2: + # Expand the attention mask for SDPA. + # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len] + if self.config.is_decoder: + extended_attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + input_shape, + embedding_output, + past_key_values_length, + ) + else: + extended_attention_mask = _prepare_4d_attention_mask_for_sdpa( + attention_mask, embedding_output.dtype, tgt_len=seq_length + ) + else: + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + + if use_sdpa_attention_masks and encoder_attention_mask.dim() == 2: + # Expand the attention mask for SDPA. 
+ # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len] + encoder_extended_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, embedding_output.dtype, tgt_len=seq_length + ) + else: + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) diff --git a/examples/modular-transformers/modeling_super.py b/examples/modular-transformers/modeling_super.py index d91bdb1820c..7df04bcc2a9 100644 --- a/examples/modular-transformers/modeling_super.py +++ b/examples/modular-transformers/modeling_super.py @@ -1,26 +1,24 @@ -# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -# This file was automatically generated from . -# Do NOT edit this file manually as any edits will be overwritten by the generation of -# the file from the diff. If any change should be done, please apply the change to the -# diff.py file directly. One of our CI enforces this -# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from examples/modular-transformers/modular_super.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_super.py file directly. One of our CI enforces this. 
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 import math from typing import List, Optional, Tuple, Union import torch import torch.nn.functional as F -import torch.utils.checkpoint from torch import nn from ...activations import ACT2FN from ...cache_utils import Cache, StaticCache from ...modeling_attn_mask_utils import AttentionMaskConverter -from ...modeling_flash_attention_utils import _flash_attention_forward -from ...modeling_outputs import ( - BaseModelOutputWithPast, -) +from ...modeling_flash_attention_utils import FlashAttentionKwargs, _flash_attention_forward +from ...modeling_outputs import BaseModelOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import PreTrainedModel +from ...processing_utils import Unpack from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, @@ -33,59 +31,6 @@ logger = logging.get_logger(__name__) -def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - device: torch.device, - min_dtype: float, - cache_position: torch.Tensor, - batch_size: int, -): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to plcae the 4D attention mask on. - min_dtype (`float`): - The minimum value representable with the dtype `dtype`. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - causal_mask = attention_mask - else: - causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask - - class SuperRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ @@ -123,7 +68,7 @@ def __init__( if config is None: logger.warning_once( "`SuperRotaryEmbedding` can now be fully parameterized by passing the model config through the " - "`config` argument. 
All other arguments will be removed in v4.45" + "`config` argument. All other arguments will be removed in v4.46" ) self.rope_kwargs = { "rope_type": rope_type, @@ -193,40 +138,6 @@ def forward(self, x, position_ids): return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): - """Applies Rotary Position Embedding to the query and key tensors. - - Args: - q (`torch.Tensor`): The query tensor. - k (`torch.Tensor`): The key tensor. - cos (`torch.Tensor`): The cosine part of the rotary embedding. - sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. - Returns: - `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. - """ - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - class SuperMLP(nn.Module): def __init__(self, config): super().__init__() @@ -261,6 +172,40 @@ def forward(self, x): return down_proj +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
+ """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, @@ -302,7 +247,7 @@ def __init__(self, config: SuperConfig, layer_idx: Optional[int] = None): self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) - # TODO (joao): remove in v4.45 (RoPE is computed in the model, not in the decoder layers) + # TODO (joao): remove in v4.46 (RoPE is computed in the model, not in the decoder layers) self.rotary_emb = SuperRotaryEmbedding(config=self.config) def forward( @@ -314,7 +259,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() @@ -349,7 +294,7 @@ def forward( logger.warning_once( "The attention layers in this model are transitioning from computing the RoPE embeddings internally " "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " - "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " "removed and `position_embeddings` will be mandatory." ) cos, sin = self.rotary_emb(value_states, position_ids) @@ -422,7 +367,8 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if isinstance(past_key_value, StaticCache): raise ValueError( @@ -449,7 +395,7 @@ def forward( logger.warning_once( "The attention layers in this model are transitioning from computing the RoPE embeddings internally " "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " - "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " "removed and `position_embeddings` will be mandatory." 
) cos, sin = self.rotary_emb(value_states, position_ids) @@ -507,6 +453,7 @@ def forward( sliding_window=getattr(self, "sliding_window", None), use_top_left_mask=self._flash_attn_uses_top_left_mask, is_causal=self.is_causal, + **kwargs, ) attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() @@ -535,7 +482,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if output_attentions: @@ -569,7 +516,7 @@ def forward( logger.warning_once( "The attention layers in this model are transitioning from computing the RoPE embeddings internally " "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " - "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " "removed and `position_embeddings` will be mandatory." ) cos, sin = self.rotary_emb(value_states, position_ids) @@ -644,7 +591,7 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -790,7 +737,8 @@ def _init_weights(self, module): returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. Two formats are allowed: - - a [`~cache_utils.Cache`] instance; + - a [`~cache_utils.Cache`] instance, see our + [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache); - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy cache format. @@ -916,10 +864,9 @@ def _update_causal_mask( return None dtype, device = input_tensor.dtype, input_tensor.device - min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] if using_static_cache: - target_length = past_key_values.get_max_length() + target_length = past_key_values.get_max_cache_shape() else: target_length = ( attention_mask.shape[-1] @@ -928,13 +875,12 @@ def _update_causal_mask( ) # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). - causal_mask = _prepare_4d_causal_attention_mask_with_cache_position( + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( attention_mask, sequence_length=sequence_length, target_length=target_length, dtype=dtype, device=device, - min_dtype=min_dtype, cache_position=cache_position, batch_size=input_tensor.shape[0], ) @@ -948,6 +894,63 @@ def _update_causal_mask( # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. 
# Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask diff --git a/examples/modular-transformers/modular_roberta.py b/examples/modular-transformers/modular_roberta.py index a3e0218f932..13dca4845c1 100644 --- a/examples/modular-transformers/modular_roberta.py +++ b/examples/modular-transformers/modular_roberta.py @@ -13,8 +13,5 @@ def __init__(self, config): class RobertaModel(BertModel): - def __init__(self, config): + def __init__(self, config, add_pooling_layer=True): super().__init__(self, config) - # Error out here. Why? Because `RobertaEmbeddings` is defined but not used. - # no, because it's defined, and RobertaModel should use RobertaEmbedding - # here if initialized that way it won't use the new embedding. 
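
The `modeling_super.py` hunk above moves the causal-mask construction into a static `_prepare_4d_causal_attention_mask_with_cache_position` method and drops the `min_dtype` argument (it is now derived from `dtype` inside the method). The standalone sketch below mirrors that logic on toy shapes to show what the resulting 4D mask looks like; the `build_causal_mask` wrapper name and the toy inputs are illustrative only and are not part of the generated files.

```python
import torch


def build_causal_mask(attention_mask, sequence_length, target_length, dtype, device, cache_position, batch_size):
    # Mirrors the generated static method: an additive mask of shape
    # (batch_size, 1, query_length, key_value_length), filled with the dtype's
    # minimum where attention is disallowed and 0 where it is allowed.
    min_dtype = torch.finfo(dtype).min
    causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
    if sequence_length != 1:
        causal_mask = torch.triu(causal_mask, diagonal=1)
    # Keep the mask only for cache slots beyond the current token positions
    # (the not-yet-filled part of a static cache).
    causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
    causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
    if attention_mask is not None:
        causal_mask = causal_mask.clone()  # contiguous copy for in-place edit
        mask_length = attention_mask.shape[-1]
        padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
        causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
            padding_mask == 0, min_dtype
        )
    return causal_mask


# Toy prefill step: batch of 2, 3 query tokens, a static cache of 5 slots, one padded position.
mask_2d = torch.tensor([[1, 1, 1], [0, 1, 1]], dtype=torch.float32)
mask_4d = build_causal_mask(
    mask_2d,
    sequence_length=3,
    target_length=5,
    dtype=torch.float32,
    device="cpu",
    cache_position=torch.arange(3),
    batch_size=2,
)
print(mask_4d.shape)  # torch.Size([2, 1, 3, 5])
```

Padding columns (here the first key slot of the second batch element) and the not-yet-causal cache slots both end up at the dtype minimum, while allowed positions stay at 0, which is the form consumed by the eager and SDPA attention paths shown above.
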
diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 9a4de1022c5..fa3fadc4349 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -23,7 +23,6 @@ from typing import List, Optional, Tuple, Union import torch -import torch.utils.checkpoint from torch import nn from ...activations import ACT2FN @@ -49,7 +48,10 @@ from .configuration_gemma import GemmaConfig +logger = logging.get_logger(__name__) + _CHECKPOINT_FOR_DOC = "google/gemma-7b" +_CONFIG_FOR_DOC = "GemmaConfig" class GemmaRMSNorm(nn.Module): @@ -72,9 +74,6 @@ def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.eps}" -logger = logging.get_logger(__name__) - - class GemmaRotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() @@ -624,9 +623,6 @@ def _init_weights(self, module): module.weight.data[module.padding_idx].zero_() -_CONFIG_FOR_DOC = "GemmaConfig" - - GEMMA_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): diff --git a/src/transformers/models/gemma2/configuration_gemma2.py b/src/transformers/models/gemma2/configuration_gemma2.py index 74976bdd340..45006b8ca2f 100644 --- a/src/transformers/models/gemma2/configuration_gemma2.py +++ b/src/transformers/models/gemma2/configuration_gemma2.py @@ -19,8 +19,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - - from ...configuration_utils import PretrainedConfig diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 6d61c47619f..626e5537fc0 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -23,7 +23,6 @@ import torch import torch.nn as nn -import torch.utils.checkpoint from ...activations import ACT2FN from ...cache_utils import Cache, HybridCache @@ -40,6 +39,7 @@ add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, + is_flash_attn_2_available, is_flash_attn_greater_or_equal, is_flash_attn_greater_or_equal_2_10, logging, @@ -48,7 +48,15 @@ from .configuration_gemma2 import Gemma2Config +if is_flash_attn_2_available(): + from ...modeling_flash_attention_utils import _flash_attention_forward + + +logger = logging.get_logger(__name__) + + _CHECKPOINT_FOR_DOC = "google/gemma2-7b" +_CONFIG_FOR_DOC = "Gemma2Config" class Gemma2RMSNorm(nn.Module): @@ -86,9 +94,6 @@ def forward(self, x): return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) -logger = logging.get_logger(__name__) - - class Gemma2RotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() @@ -198,12 +203,12 @@ def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None): self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) - self.sliding_window = config.sliding_window if not bool(layer_idx % 2) else None self.rotary_emb = Gemma2RotaryEmbedding( self.head_dim, 
max_position_embeddings=self.max_position_embeddings, base=self.rope_theta, ) + self.sliding_window = config.sliding_window if not bool(layer_idx % 2) else None def forward( self, @@ -495,12 +500,12 @@ def __init__(self, config: Gemma2Config, layer_idx: int): self.self_attn = GEMMA2_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) self.mlp = Gemma2MLP(config) self.input_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.config = config self.is_sliding = not bool(layer_idx % 2) self.pre_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.sliding_window = config.sliding_window - self.post_attention_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( self, @@ -638,9 +643,6 @@ def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False): return config -_CONFIG_FOR_DOC = "Gemma2Config" - - GEMMA2_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): @@ -865,6 +867,7 @@ def forward( attentions=all_self_attns, ) + @torch.no_grad() def _update_causal_mask( self, attention_mask: torch.Tensor, diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index 5f8eaf89ed9..248ec402179 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -24,7 +24,6 @@ import torch import torch.nn as nn -import torch.utils.checkpoint from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache @@ -50,7 +49,10 @@ from .configuration_glm import GlmConfig +logger = logging.get_logger(__name__) + _CHECKPOINT_FOR_DOC = "THUDM/glm-4-9b" +_CONFIG_FOR_DOC = "GlmConfig" class GlmRMSNorm(nn.Module): @@ -121,7 +123,16 @@ def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: return self.down_proj(up_states) -logger = logging.get_logger(__name__) +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) def rotate_half(x): @@ -172,18 +183,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): return q_embed, k_embed -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - class GlmAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -608,9 +607,6 @@ def _init_weights(self, module): module.weight.data[module.padding_idx].zero_() -_CONFIG_FOR_DOC = "GlmConfig" - - GLM_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): diff --git a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py index c9f12391666..0a8f383380d 100644 --- a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py +++ b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py @@ -24,7 +24,6 @@ from typing import Any, Optional, Tuple, Union import torch -import torch.utils.checkpoint from torch import nn from torch.nn import CrossEntropyLoss @@ -347,104 +346,6 @@ def _init_weights(self, module): module.bias.data.zero_() -INSTRUCTBLIPVIDEO_START_DOCSTRING = r""" - This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage - and behavior. - - Parameters: - config ([`InstructBlipVideoConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - -INSTRUCTBLIPVIDEO_VISION_INPUTS_DOCSTRING = r""" - Args: - pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): - Pixel values. Pixel values can be obtained using [`InstructBlipVideoProcessor`]. See - [`InstructBlipVideoProcessor.__call__`] for details. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - interpolate_pos_encoding (`bool`, *optional*, defaults to `False`): - Whether to interpolate the pre-trained position encodings. -""" - -INSTRUCTBLIPVIDEO_INPUTS_DOCSTRING = r""" - Args: - pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): - Pixel values. Pixel values can be obtained using [`InstructBlipVideoProcessor`]. See - [`InstructBlipVideoProcessor.__call__`] for details. 
- - qformer_input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of input sequence tokens in the vocabulary of the Q-Former. Input tokens can optionally be provided - to serve as text prompt, which the Q-Former model will encode. - - Indices can be obtained using [`InstructBlipVideoProcessor`]. See [`InstructBlipVideoProcessor.__call__`] for - details. - - [What are input IDs?](../glossary#input-ids) - - qformer_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of input sequence tokens in the vocabulary of the language model. Input tokens can optionally be - provided to serve as text prompt, which the language model can continue. - - Indices can be obtained using [`InstructBlipVideoProcessor`]. See [`InstructBlipVideoProcessor.__call__`] for - details. - - [What are input IDs?](../glossary#input-ids) - - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - - decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): - Indices of decoder input sequence tokens in the vocabulary of the language model. Only relevant in case an - encoder-decoder language model (like T5) is used. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. [What are decoder input IDs?](../glossary#decoder-input-ids) - - decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*): - Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also - be used by default. - - Only relevant in case an encoder-decoder language model (like T5) is used. - - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - interpolate_pos_encoding (`bool`, *optional*, defaults to `False`): - Whether to interpolate the pre-trained position encodings. -""" - - class InstructBlipVideoEncoder(nn.Module): """ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a @@ -531,6 +432,24 @@ def forward( ) +INSTRUCTBLIPVIDEO_VISION_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`InstructBlipVideoProcessor`]. See + [`InstructBlipVideoProcessor.__call__`] for details. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. 
See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + interpolate_pos_encoding (`bool`, *optional*, defaults to `False`): + Whether to interpolate the pre-trained position encodings. +""" + + class InstructBlipVideoVisionModel(InstructBlipVideoPreTrainedModel): main_input_name = "pixel_values" config_class = InstructBlipVideoVisionConfig @@ -1268,6 +1187,87 @@ def forward( ) +INSTRUCTBLIPVIDEO_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`InstructBlipVideoConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +INSTRUCTBLIPVIDEO_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`InstructBlipVideoProcessor`]. See + [`InstructBlipVideoProcessor.__call__`] for details. + + qformer_input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of input sequence tokens in the vocabulary of the Q-Former. Input tokens can optionally be provided + to serve as text prompt, which the Q-Former model will encode. + + Indices can be obtained using [`InstructBlipVideoProcessor`]. See [`InstructBlipVideoProcessor.__call__`] for + details. + + [What are input IDs?](../glossary#input-ids) + + qformer_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of input sequence tokens in the vocabulary of the language model. Input tokens can optionally be + provided to serve as text prompt, which the language model can continue. + + Indices can be obtained using [`InstructBlipVideoProcessor`]. See [`InstructBlipVideoProcessor.__call__`] for + details. + + [What are input IDs?](../glossary#input-ids) + + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary of the language model. 
Only relevant in case an + encoder-decoder language model (like T5) is used. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. [What are decoder input IDs?](../glossary#decoder-input-ids) + + decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + + Only relevant in case an encoder-decoder language model (like T5) is used. + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + interpolate_pos_encoding (`bool`, *optional*, defaults to `False`): + Whether to interpolate the pre-trained position encodings. +""" + + @add_start_docstrings( """ InstructBlipVideo Model for generating text given an image and an optional text prompt. The model consists of a vision diff --git a/src/transformers/models/llava_next_video/modeling_llava_next_video.py b/src/transformers/models/llava_next_video/modeling_llava_next_video.py index a2328c1d2d9..73118f4bfcd 100644 --- a/src/transformers/models/llava_next_video/modeling_llava_next_video.py +++ b/src/transformers/models/llava_next_video/modeling_llava_next_video.py @@ -25,7 +25,6 @@ import numpy as np import torch -import torch.utils.checkpoint from torch import nn from ...activations import ACT2FN @@ -33,12 +32,7 @@ from ...image_processing_utils import select_best_resolution from ...modeling_outputs import ModelOutput from ...modeling_utils import PreTrainedModel -from ...utils import ( - add_start_docstrings, - add_start_docstrings_to_model_forward, - logging, - replace_return_docstrings, -) +from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings from ..auto import AutoModel, AutoModelForCausalLM from .configuration_llava_next_video import LlavaNextVideoConfig @@ -48,113 +42,6 @@ _CONFIG_FOR_DOC = "LlavaNextVideoConfig" -def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): - """ - Calculate the shape of the image patch grid after the preprocessing for images of any resolution. - - Args: - image_size (`tuple`): - The size of the input image in the format (width, height). - grid_pinpoints (`List`): - A list containing possible resolutions. Each item in the list should be a tuple or list - of the form `(height, width)`. - patch_size (`int`): - The size of each image patch. - - Returns: - tuple: The shape of the image patch grid in the format (width, height). - """ - if not isinstance(grid_pinpoints, list): - raise TypeError("grid_pinpoints should be a list of tuples or lists") - - # ! 
VERY IMPORTANT if image_size is tensor, must convert to into tuple, otherwise it will cause wrong calculate - if not isinstance(image_size, (list, tuple)): - if not isinstance(image_size, (torch.Tensor, np.ndarray)): - raise TypeError( - f"image_size invalid type: {type(image_size)} not valid, should be either list, tuple, np.ndarray or tensor" - ) - image_size = image_size.tolist() - - height, width = select_best_resolution(image_size, grid_pinpoints) - return height // patch_size, width // patch_size - - -def image_size_to_num_patches(image_size, grid_pinpoints, patch_size: int): - """ - Calculate the number of patches after the preprocessing for images of any resolution. - - Args: - image_size (`torch.LongTensor` or `np.ndarray` or `Tuple[int, int]`): - The size of the input image in the format (height, width). ? - grid_pinpoints (`List`): - A list containing possible resolutions. Each item in the list should be a tuple or list - of the form `(height, width)`. - patch_size (`int`): - The size of each image patch. - - Returns: - int: the number of patches - """ - if not isinstance(grid_pinpoints, list): - raise TypeError("grid_pinpoints should be a list of tuples or lists") - - # ! VERY IMPORTANT if image_size is tensor, must convert to into tuple, otherwise it will cause wrong calculate - if not isinstance(image_size, (list, tuple)): - if not isinstance(image_size, (torch.Tensor, np.ndarray)): - raise TypeError(f"image_size invalid type {type(image_size)} with value {image_size}") - image_size = image_size.tolist() - - best_resolution = select_best_resolution(image_size, grid_pinpoints) - height, width = best_resolution - num_patches = 0 - # consider change to ceil(height/patch_size)*ceil(width/patch_size) + 1 - for i in range(0, height, patch_size): - for j in range(0, width, patch_size): - num_patches += 1 - # add the base patch - num_patches += 1 - return num_patches - - -def unpad_image(tensor, original_size): - """ - Unpads a PyTorch tensor of a padded and resized image. - - Args: - tensor (`torch.Tensor`): - The image tensor, assumed to be of shape (num_channels, height, width). - original_size (`tuple`): - The original size of the image (height, width). - - Returns: - `torch.Tensor`: The unpadded image tensor. 
- """ - if not isinstance(original_size, (list, tuple)): - if not isinstance(original_size, (torch.Tensor, np.ndarray)): - raise TypeError( - f"image_size invalid type: {type(original_size)} not valid, should be either list, tuple, np.ndarray or tensor" - ) - original_size = original_size.tolist() - original_height, original_width = original_size - current_height, current_width = tensor.shape[1:] - - original_aspect_ratio = original_width / original_height - current_aspect_ratio = current_width / current_height - - if original_aspect_ratio > current_aspect_ratio: - scale_factor = current_width / original_width - new_height = int(round(original_height * scale_factor, 7)) - padding = (current_height - new_height) // 2 - unpadded_tensor = tensor[:, padding : current_height - padding, :] - else: - scale_factor = current_height / original_height - new_width = int(round(original_width * scale_factor, 7)) - padding = (current_width - new_width) // 2 - unpadded_tensor = tensor[:, :, padding : current_width - padding] - - return unpadded_tensor - - @dataclass class LlavaNextVideoCausalLMOutputWithPast(ModelOutput): """ @@ -304,6 +191,113 @@ def _init_weights(self, module): module.weight.data[module.padding_idx].zero_() +def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): + """ + Calculate the shape of the image patch grid after the preprocessing for images of any resolution. + + Args: + image_size (`tuple`): + The size of the input image in the format (width, height). + grid_pinpoints (`List`): + A list containing possible resolutions. Each item in the list should be a tuple or list + of the form `(height, width)`. + patch_size (`int`): + The size of each image patch. + + Returns: + tuple: The shape of the image patch grid in the format (width, height). + """ + if not isinstance(grid_pinpoints, list): + raise TypeError("grid_pinpoints should be a list of tuples or lists") + + # ! VERY IMPORTANT if image_size is tensor, must convert to into tuple, otherwise it will cause wrong calculate + if not isinstance(image_size, (list, tuple)): + if not isinstance(image_size, (torch.Tensor, np.ndarray)): + raise TypeError( + f"image_size invalid type: {type(image_size)} not valid, should be either list, tuple, np.ndarray or tensor" + ) + image_size = image_size.tolist() + + height, width = select_best_resolution(image_size, grid_pinpoints) + return height // patch_size, width // patch_size + + +def image_size_to_num_patches(image_size, grid_pinpoints, patch_size: int): + """ + Calculate the number of patches after the preprocessing for images of any resolution. + + Args: + image_size (`torch.LongTensor` or `np.ndarray` or `Tuple[int, int]`): + The size of the input image in the format (height, width). ? + grid_pinpoints (`List`): + A list containing possible resolutions. Each item in the list should be a tuple or list + of the form `(height, width)`. + patch_size (`int`): + The size of each image patch. + + Returns: + int: the number of patches + """ + if not isinstance(grid_pinpoints, list): + raise TypeError("grid_pinpoints should be a list of tuples or lists") + + # ! 
VERY IMPORTANT if image_size is tensor, must convert to into tuple, otherwise it will cause wrong calculate + if not isinstance(image_size, (list, tuple)): + if not isinstance(image_size, (torch.Tensor, np.ndarray)): + raise TypeError(f"image_size invalid type {type(image_size)} with value {image_size}") + image_size = image_size.tolist() + + best_resolution = select_best_resolution(image_size, grid_pinpoints) + height, width = best_resolution + num_patches = 0 + # consider change to ceil(height/patch_size)*ceil(width/patch_size) + 1 + for i in range(0, height, patch_size): + for j in range(0, width, patch_size): + num_patches += 1 + # add the base patch + num_patches += 1 + return num_patches + + +def unpad_image(tensor, original_size): + """ + Unpads a PyTorch tensor of a padded and resized image. + + Args: + tensor (`torch.Tensor`): + The image tensor, assumed to be of shape (num_channels, height, width). + original_size (`tuple`): + The original size of the image (height, width). + + Returns: + `torch.Tensor`: The unpadded image tensor. + """ + if not isinstance(original_size, (list, tuple)): + if not isinstance(original_size, (torch.Tensor, np.ndarray)): + raise TypeError( + f"image_size invalid type: {type(original_size)} not valid, should be either list, tuple, np.ndarray or tensor" + ) + original_size = original_size.tolist() + original_height, original_width = original_size + current_height, current_width = tensor.shape[1:] + + original_aspect_ratio = original_width / original_height + current_aspect_ratio = current_width / current_height + + if original_aspect_ratio > current_aspect_ratio: + scale_factor = current_width / original_width + new_height = int(round(original_height * scale_factor, 7)) + padding = (current_height - new_height) // 2 + unpadded_tensor = tensor[:, padding : current_height - padding, :] + else: + scale_factor = current_height / original_height + new_width = int(round(original_width * scale_factor, 7)) + padding = (current_width - new_width) // 2 + unpadded_tensor = tensor[:, :, padding : current_width - padding] + + return unpadded_tensor + + LLAVA_NEXT_VIDEO_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): diff --git a/src/transformers/models/llava_next_video/modular_llava_next_video.py b/src/transformers/models/llava_next_video/modular_llava_next_video.py index 2025140bb6e..8018afa7244 100644 --- a/src/transformers/models/llava_next_video/modular_llava_next_video.py +++ b/src/transformers/models/llava_next_video/modular_llava_next_video.py @@ -30,7 +30,6 @@ from ...configuration_utils import PretrainedConfig from ...utils import ( logging, - replace_return_docstrings, ) from ..auto import CONFIG_MAPPING @@ -309,7 +308,6 @@ def get_video_features( video_features = torch.split(video_features, frames, dim=0) return video_features - @replace_return_docstrings(output_type=LlavaNextVideoCausalLMOutputWithPast, config_class="LlavaNextVideoConfig") def forward( self, input_ids: torch.LongTensor = None, diff --git a/utils/check_modular_conversion.py b/utils/check_modular_conversion.py index 09b237c1e6c..86af396e03a 100644 --- a/utils/check_modular_conversion.py +++ b/utils/check_modular_conversion.py @@ -4,6 +4,8 @@ import logging from io import StringIO +from create_dependency_mapping import find_priority_list + # Console for rich printing from modular_model_converter import convert_modular_file from rich.console import Console @@ -69,7 +71,7 @@ def compare_files(modular_file_path, fix_and_overwrite=False): 
if args.files == ["all"]: args.files = glob.glob("src/transformers/models/**/modular_*.py", recursive=True) non_matching_files = 0 - for modular_file_path in args.files: + for modular_file_path in find_priority_list(args.files): non_matching_files += compare_files(modular_file_path, args.fix_and_overwrite) if non_matching_files and not args.fix_and_overwrite: diff --git a/utils/modular_model_converter.py b/utils/modular_model_converter.py index bda143c2577..b1dfa18a7a9 100644 --- a/utils/modular_model_converter.py +++ b/utils/modular_model_converter.py @@ -17,13 +17,14 @@ import importlib import os import re +from abc import ABC, abstractmethod from collections import defaultdict, deque -from typing import Dict, List, Optional, Set +from typing import Dict, Set import libcst as cst from check_copies import run_ruff from create_dependency_mapping import find_priority_list -from libcst import ClassDef, CSTTransformer, CSTVisitor +from libcst import ClassDef, CSTVisitor from libcst import matchers as m from libcst.metadata import MetadataWrapper, ParentNodeProvider, PositionProvider, ScopeProvider @@ -34,13 +35,6 @@ logger = logging.get_logger(__name__) -# This is used to avoid overwriting these top-level assignments even if they are in the dependency graph. Otherwise, the -# value from the dependency is used, then mapped to current name convention, resulting in wrong value. -# The corresponding mapped value is used to define the file target for the assignment -ASSIGNMENTS_TO_KEEP = { - "_CHECKPOINT_FOR_DOC": "modeling", -} - AUTO_GENERATED_MESSAGE = """# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # This file was automatically generated from {relative_path}. # Do NOT edit this file manually as any edits will be overwritten by the generation of @@ -61,137 +55,23 @@ def get_module_source_from_name(module_name: str) -> str: return source_code -class ClassFinder(CSTVisitor): - """A visitor class which analyses a module, creating a mapping of dependencies between classes and functions. - For example if the visited code has - ```python3 - def init_value(): return 1 - - class LlamaModel(PreTrainedModel): - def __init__(self): - super().__init__(self) - self.value = init_value() - ``` - then the `class_dependency_mapping` should be: `{"LlamaModel":["PreTrainedModel","init_value"], "init_value":[]} - - The dependency mapping is updated via the `visit_Name`, `visit_Arg` and `visit_Decorator`. This is very broad, and by - checking the parent node, or the scope of a `cst.Name` or `cst.Arg` or `cst.Decorator` we are able to map the - dependence parent -> child. - - When visiting such nodes, we update the dependency of the parent node, to take into account the visited node. - - All `visit_XXX` correspond to the code executed when vising the cst.Node of type XXX. 
- """ - - METADATA_DEPENDENCIES = (ParentNodeProvider, ScopeProvider, PositionProvider) +def preserve_case_replace(text, patterns: dict, default_name: str): + # Create a regex pattern to match all variations + regex_pattern = "|".join(re.escape(key) for key in patterns.keys()) + compiled_regex = re.compile(regex_pattern, re.IGNORECASE) - def __init__(self, python_module: cst.Module): - # fmt: off - self.python_module: cst.Module = python_module # original cst.Module being visited - self.classes: Dict[str, cst.ClassDef] = {} # stores a mapping from classname to the cst.Node - self.imports = {} # stores all import statements - self.function_def = {} # stores global scope function definition - self.assignments = {} # LLAMA_DOCSTRING - self.class_dependency_mapping = {} # "LlamaModel":["LlamaDecoderLayer, "LlamaRMSNorm", "LlamaPreTrainedModel"], "LlamaDecoderLayer":["LlamaAttention","Llama"] - self.first_lvl_dependency_mapping = {} # "LlamaModel":["LlamaDecoderLayer, "LlamaRMSNorm", "LlamaPreTrainedModel"], "LlamaDecoderLayer":["LlamaAttention","Llama"] - # fmt: on - - def _update_class_dependency(self, name, value): - """Update the dependency mapping for `name` with `value` by appending the previous - dependencies to the new `value`. - """ - dep = set(self.first_lvl_dependency_mapping.get(name, set())) | set({value}) - self.first_lvl_dependency_mapping[name] = dep - - dep = set(self.class_dependency_mapping.get(value, set())) - dep |= set(self.class_dependency_mapping.get(name, {})) | set({value}) - self.class_dependency_mapping[name] = dep - - def visit_ClassDef(self, node: ClassDef) -> None: - """We don't have non global scope class defs in transformers. Here we add the inheritance dependencies""" - self.classes[node.name.value] = node - for k in node.bases: # deal with inheritance - base_name = self.python_module.code_for_node(k) - self._update_class_dependency(node.name.value, base_name) + def replace(match): + word = match.group(0) + result = patterns.get(word, default_name) + return result - def visit_SimpleStatementLine(self, node): - """ - Global Assigns like `GEMMA_INPUT_DOCSTRING = 'THIS IS THE INPUT' and all import statements - are extracted and saved in their corresponding dict. They are then used when updating dependency mappings. 
- """ - if m.matches(node, m.SimpleStatementLine(body=[m.Assign()])) and m.matches( - self.get_metadata(cst.metadata.ParentNodeProvider, node), m.Module() - ): - left_hand_side = node.body[0].targets[0].target - if hasattr(left_hand_side, "value"): - if left_hand_side.value not in ASSIGNMENTS_TO_KEEP.keys(): - self.assignments[left_hand_side.value] = node - else: - for idx, target in enumerate(list(left_hand_side.elements)): - if target.value.value not in ASSIGNMENTS_TO_KEEP.keys(): - self.assignments[target.value.value] = node.body[0].value.elements[idx].value - if m.matches(node, m.SimpleStatementLine(body=[m.Import() | m.ImportFrom()])): - self.imports[node.body[0].names] = node + return compiled_regex.sub(replace, text) - def visit_FunctionDef(self, node): - parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node) - if m.matches(parent_node, m.Module()): - self.function_def[node.name.value] = node - def leave_If(self, node): - for stmt in node.body.body: - if m.matches(stmt, m.SimpleStatementLine(body=[m.ImportFrom() | m.Import()])): - self.imports[stmt.body[0].names] = node - - def leave_Name(self, node): - if node.value in self.classes.keys() | self.assignments.keys() | self.function_def.keys(): - parent = self.get_metadata(cst.metadata.ScopeProvider, node) - if not isinstance(parent, cst.metadata.scope_provider.GlobalScope): - self._update_class_dependency(parent._name_prefix.split(".")[0], node.value) - - def leave_Arg(self, node): - if m.matches(node.value, m.Name()): - parent = self.get_metadata(ParentNodeProvider, node) - if m.matches(parent, m.ClassDef()) and parent.bases: - self._update_class_dependency(parent.name.value, node.value.value) - - def leave_Dict(self, node): - parent = self.get_metadata(cst.metadata.ParentNodeProvider, node) - if m.matches(parent, m.Assign(targets=[m.AssignTarget()])): - name = parent.targets[0].target.value - if name in self.assignments: - for k in node.elements: - dep_name = k.value.value - if dep_name in self.classes: - self._update_class_dependency(name, dep_name) - - def leave_Decorator(self, node): - if hasattr(node.decorator, "args"): - for k in node.decorator.args: - if m.matches(k.value, m.Call(func=m.Attribute(value=m.Name()))): # and k.value.func.value.value: - if k.value.func.value.value not in self.assignments: - raise ValueError( - f"We detected a call to {k.value.func.value.value}, but it was not assigned. See the list of assigments {self.assignments.keys()}" - ) - parent = self.get_metadata(cst.metadata.ParentNodeProvider, node) - scope = self.get_metadata(cst.metadata.ScopeProvider, node) - name = scope._name_prefix.split(".")[0] if scope._name_prefix != "" else parent.name.value - self._update_class_dependency(name, k.value.func.value.value) - elif m.matches(k, m.Arg(value=m.Name())) and k.value.value in self.assignments: - parent = self.get_metadata(cst.metadata.ParentNodeProvider, node) - scope = self.get_metadata(cst.metadata.ScopeProvider, node) - name = scope._name_prefix.split(".")[0] if scope._name_prefix != "" else parent.name.value - self._update_class_dependency(name, k.value.value) - - def leave_Module(self, node): - """When leaving the module, we store the position of each global scoped node (Assigns, function def and class def) - to allow sorting the dependencies based on their position in the code. We use the PositionProvider metadata wrapper for this. 
- """ - self.global_nodes = {**self.assignments, **self.classes, **self.function_def} - # now sort the class dependency_mapping based on the position of the nodes - self.class_start_line = {} - for id, node in self.global_nodes.items(): - self.class_start_line[id] = self.get_metadata(cst.metadata.PositionProvider, node).start.line +def convert_to_camelcase(text, old_name: str, default_old_name: str): + # Regex pattern to match consecutive uppercase letters and lowercase the first set + result = re.sub(rf"^({old_name})(?=[a-z]+)", lambda m: default_old_name, text, flags=re.IGNORECASE, count=1) + return result class ReplaceNameTransformer(m.MatcherDecoratableTransformer): @@ -210,8 +90,6 @@ def __init__( new_name, given_old_name=None, given_new_name=None, - old_class_name: str = None, - new_class_name: str = None, ): super().__init__() self.old_name = old_name @@ -232,70 +110,17 @@ def __init__( self.default_old_name = CONFIG_MAPPING_NAMES[self.old_name].replace("Config", "") if self.default_old_name.isupper(): self.default_old_name = self.default_old_name.capitalize() - if new_class_name is not None and old_class_name is not None and old_class_name not in self.patterns: - # In last recourse, when the suffix of the new class is not the same as the old class, - # and if the old and new classes start with the default name, we keep the default class name - # and replace the old suffix with the new one. - # Useful when we have a class like `ColPaliForRetrieval` inheriting from `PaliGemmaForConditionalGeneration` - # where a model extends another model, but is used for a different task. - if old_class_name.startswith(self.default_old_name) and new_class_name.startswith(self.default_name): - self.patterns[old_class_name[len(self.default_old_name) :]] = new_class_name[len(self.default_name) :] - - def preserve_case_replace(self, text): - # Create a regex pattern to match all variations - regex_pattern = "|".join(re.escape(key) for key in self.patterns.keys()) - compiled_regex = re.compile(regex_pattern, re.IGNORECASE) - - def replace(match): - word = match.group(0) - result = self.patterns.get(word, self.default_name) - return result - - return compiled_regex.sub(replace, text) - - def convert_to_camelcase(self, text): - # Regex pattern to match consecutive uppercase letters and lowercase the first set - result = re.sub( - rf"^({self.old_name})(?=[a-z]+)", lambda m: self.default_old_name, text, flags=re.IGNORECASE, count=1 - ) - return result @m.leave(m.Name() | m.SimpleString() | m.Comment()) def replace_name(self, original_node, updated_node): if re.findall(r"# Copied from", updated_node.value): return cst.RemoveFromParent() - update = self.preserve_case_replace(updated_node.value) + update = preserve_case_replace(updated_node.value, self.patterns, self.default_name) return updated_node.with_changes(value=update) def leave_ClassDef(self, original_node, updated_node): - return updated_node.with_changes(name=cst.Name(self.convert_to_camelcase(updated_node.name.value))) - - -def find_classes_in_file( - module: cst.Module, - old_id="llama", - new_id="gemma", - given_old_name=None, - given_new_name=None, - old_class_name=None, - new_class_name=None, -): - """Helper function to rename and then parse a source file using the ClassFinder""" - transformer = ReplaceNameTransformer( - old_id, - new_id, - given_old_name=given_old_name, - given_new_name=given_new_name, - old_class_name=old_class_name, - new_class_name=new_class_name, - ) - new_module = module.visit(transformer) - - wrapper = 
MetadataWrapper(new_module) - - class_finder = ClassFinder(new_module) - wrapper.visit(class_finder) - return class_finder + new_name = convert_to_camelcase(updated_node.name.value, self.old_name, self.default_old_name) + return updated_node.with_changes(name=cst.Name(new_name)) DOCSTRING_NODE = m.SimpleStatementLine( @@ -412,13 +237,12 @@ def merge_docstrings(original_docstring, updated_docstring): class SuperTransformer(cst.CSTTransformer): METADATA_DEPENDENCIES = (ParentNodeProvider,) - def __init__(self, python_module: cst.Module, original_methods, updated_methods, class_name="", all_bases=None): + def __init__(self, python_module: cst.Module, original_methods, updated_methods, all_bases=None): self.python_module = python_module self.original_methods = original_methods self.updated_methods = updated_methods self.all_assign_target = {} self.deleted_targets = {} # child node can delete some arguments - self.class_name = class_name self.all_bases = all_bases or [] self.transformer = ReplaceMethodCallTransformer(set(self.all_bases)) @@ -437,7 +261,6 @@ def update_body(self, existing_body, new_statements): if m.matches(node, m.SimpleStatementLine(body=[m.Del()])): target = self.python_module.code_for_node(node.body[0].target) self.deleted_targets[target] = node - continue for stmt in existing_body: if m.matches(stmt, m.SimpleStatementLine(body=[m.Assign()])): @@ -447,6 +270,9 @@ def update_body(self, existing_body, new_statements): continue if target in self.all_assign_target: stmt = self.all_assign_target[target] + # Skip the docstring (will be added later on, at the beginning) + elif m.matches(stmt, DOCSTRING_NODE): + continue comment_less_code = re.sub(r"#.*", "", self.python_module.code_for_node(stmt)).strip() comment_less_code = re.sub(r"\ *\n", "\n", comment_less_code).strip() deduplicated_new_body.append(stmt) @@ -456,17 +282,47 @@ def update_body(self, existing_body, new_statements): code = self.python_module.code_for_node(node) comment_less_code = re.sub(r"#.*", "", code).strip() comment_less_code = re.sub(r"\ *\n", "\n", comment_less_code).strip() - if ( - node not in deduplicated_new_body - and "super().__init__" not in comment_less_code - and comment_less_code not in existing_nodes - ): + if node not in deduplicated_new_body and comment_less_code not in existing_nodes: if not m.matches(node, m.SimpleStatementLine(body=[m.Del()])): - # HACK here to fix the pos_init() that has to be last we kinda do this. 
- deduplicated_new_body = deduplicated_new_body[:-1] + [node] + deduplicated_new_body[-1:] + deduplicated_new_body.append(node) existing_nodes.add(comment_less_code) + + deduplicated_new_body = self._fix_post_init_location(deduplicated_new_body) + return deduplicated_new_body + def _fix_post_init_location(self, new_body: list[cst.CSTNode]): + """Fix the location of the `post_init()` in the new body, if we added statements after the call to + `super()` (it needs to be the very last statement called)""" + # Fix the post_init() that has to be last + for i, node in enumerate(new_body): + code = self.python_module.code_for_node(node) + comment_less_code = re.sub(r"#.*", "", code).strip() + comment_less_code = re.sub(r"\ *\n", "\n", comment_less_code).strip() + if "self.post_init(" in comment_less_code and i < len(new_body) - 1: + # Remove it and add it again at the end + new_body.pop(i) + new_body.append(node) + break + return new_body + + def _fix_init_location(self, new_body): + """Fix the location of the `super().__init__()` in the new body, if we had new statements before it.""" + start_index = 0 + for i, node in enumerate(new_body): + if m.matches(node, DOCSTRING_NODE) and i == start_index: + start_index += 1 + continue + code = self.python_module.code_for_node(node) + comment_less_code = re.sub(r"#.*", "", code).strip() + comment_less_code = re.sub(r"\ *\n", "\n", comment_less_code).strip() + if "super().__init__" in comment_less_code and i > start_index: + # Remove it and add it again at the top after the docstrings + node = new_body.pop(i) + new_body = new_body[:start_index] + [node] + new_body[start_index:] + break + return new_body + def replace_super_calls(self, node: cst.IndentedBlock, func_name: str) -> cst.CSTNode: """Updates the body of the input `node`'s `func_name` function by replacing calls to super().func_name() with the source code of the parent class' `func_name`. @@ -479,10 +335,11 @@ def replace_super_calls(self, node: cst.IndentedBlock, func_name: str) -> cst.CS new_body = [] has_super_call = False - for expr in node.body: + for i, expr in enumerate(node.body): if is_call_to_super(expr, func_name): has_super_call = True - new_body.extend(self.update_body(self.original_methods[func_name].body.body, node.body)) + new_body.extend(self.update_body(self.original_methods[func_name].body.body, node.body[i + 1 :])) + new_body = self._fix_init_location(new_body) else: expr = expr.visit(self.transformer) if m.matches(expr, DOCSTRING_NODE): @@ -524,11 +381,463 @@ def leave_Return(self, original_node: cst.Return, updated_node: cst.Return) -> c return updated_node -def replace_call_to_super( - class_finder: ClassFinder, updated_node: cst.ClassDef, class_name: str, all_bases: List[str] -): +def find_all_dependencies( + dependency_mapping: Dict[str, set], + start_entity: str | None = None, + initial_dependencies: set | None = None, + initial_checked_dependencies: set | None = None, + return_parent: bool = False, +) -> list | set: + """Return all the dependencies of the given `start_entity` or `initial_dependencies`. This is basically some kind of + BFS traversal algorithm. It can either start from `start_entity`, or `initial_dependencies`. + + Args: + dependency_mapping (`Dict[str, set]`): + A mapping from entities (usually function/assignment names), to immediate dependencies. That is, for function names, + a mapping {"foo": {"bar", "test"}} would indicate that functions `bar` and `test` are immediately called + in `foo`'s definition. 
+ start_entity (str | None, *optional*): + A key of `dependency_mapping`, indicating from which entity to start the search. + initial_dependencies (set | None, *optional*): + If `start_entity` is not provided, this can be used as an alternative. In this case, the search will continue + from all the entities in `initial_dependencies`, if they are in `dependency_mapping`. + initial_checked_dependencies (set | None, *optional*): + If provided, entities already present in `initial_checked_dependencies` will not be part of the returned dependencies. + return_parent (bool, *optional*): + If `True`, will return a list consisting of tuples (dependency, parent) instead of a simple set of dependencies. Note + that the order of the items in the list reflects the traversal order. Thus, no parent can ever appear before childs. + Returns: + A set of all the dependencies, or a list of tuples `(dependency, parent)` if `return_parent=True`. + + Example: + Given the following structure in the `modular_xxx.py` file: + ``` + def foo1(): + pass + + def foo2(): + pass + + def bar(): + foo1() + + def foobar(): + bar() + foo2() + + class MyLayer(SomeOtherModelLayer): + def forward(...): + foobar() + ``` + and the `dependency_mapping` created when visiting the `modular_xxx.py` file, we get: + ``` + dependency_mapping = {'bar': {'foo1'}, 'foobar': {'bar', 'foo2'}} + find_all_dependencies(dependency_mapping, start_entity='foobar', return_parent=True) + >>> [('bar', 'foobar'), ('foo2', 'foobar'), ('foo1', 'bar')] + ``` + That is, all the functions needed (and potentially their immediate parent) so that the function to be added + in MyLayer (`foobar`) can work correctly. + """ + if initial_dependencies is None and start_entity is not None: + initial_dependencies = dependency_mapping[start_entity] + if initial_checked_dependencies is None: + initial_checked_dependencies = set() + + dependency_queue = deque(initial_dependencies) + all_dependencies = set() + all_dependencies_with_parent = [] + checked_dependencies = set(initial_checked_dependencies) + parents = {initial_dep: start_entity for initial_dep in initial_dependencies} + while len(dependency_queue) > 0: + # Pick element to visit + current = dependency_queue.popleft() + if current not in checked_dependencies: + # Add the dependencies + all_dependencies.add(current) + all_dependencies_with_parent += [(current, parents[current])] + if current in dependency_mapping.keys(): + # Update dependency queue + dependency_queue.extend(dependency_mapping[current]) + parents.update({dep: current for dep in dependency_mapping[current]}) + # add visited node to the list + checked_dependencies.add(current) + + if not return_parent: + return all_dependencies + # no child can ever appear before its parent thanks to the queue (needed to add them at the correct location in the body later) + return all_dependencies_with_parent + + +# These top-level variables will always use the value in the `modular_xxx.py` file +ASSIGNMENTS_TO_KEEP = { + "_CHECKPOINT_FOR_DOC", +} + + +class ClassDependencyMapper(CSTVisitor): + """A visitor which is designed to analyze a single class node to get all its dependencies that are shared with the set of + `global_names`. 
+ """ + + def __init__( + self, class_name: str, global_names: set[str], objects_imported_from_modeling: set[str] | None = None + ): + super().__init__() + self.class_name = class_name + self.dependencies = set() + self.global_names = global_names + self.objects_imported_from_modeling = ( + set() if objects_imported_from_modeling is None else objects_imported_from_modeling + ) + + def visit_Name(self, node): + if ( + node.value != self.class_name + and node.value in self.global_names + and node.value not in self.objects_imported_from_modeling + ): + self.dependencies.add(node.value) + + +def dependencies_for_class_node(node: cst.ClassDef, global_names: set[str]) -> set: + """Create immediate dependencies for a class node based on the `global_names`.""" + temp_module = cst.Module(body=[node]) + visitor = ClassDependencyMapper(node.name.value, global_names) + temp_module.visit(visitor) + return visitor.dependencies + + +def augmented_dependencies_for_class_node( + node: cst.ClassDef, mapper: "ModuleMapper", objects_imported_from_modeling: set[str] | None = None +) -> set: + """Create augmented dependencies for a class node based on a `mapper`. + Augmented dependencies means immediate dependencies + recursive function and assignments dependencies. + """ + temp_module = cst.Module(body=[node]) + visitor = ClassDependencyMapper(node.name.value, set(mapper.global_nodes.keys()), objects_imported_from_modeling) + temp_module.visit(visitor) + return mapper.augment_dependencies(visitor.dependencies) + + +# All the potential file types to create +ALL_FILE_TYPES = ( + "modeling", + "configuration", + "tokenization", + "processing", + "image_processing", + "feature_extractor", +) + + +class ModuleMapper(CSTVisitor, ABC): + """An abstract visitor class which analyses a module, creating a mapping of dependencies for classes, functions and assignments. + Class dependencies are computed with `compute_class_dependencies()`, while function and assignment dependencies are stored in + `self.object_recursive_dependency_mapping` (can be computed by `_compute_recursive_object_dependencies()`). + It defines common visiting patterns (i.e. common visit_xxx/leave_xxx functions) between the modular file and the + modeling files that will be visited. """ - Given the `class_name`, the `updated_node`'s call to super are unpacked. + + METADATA_DEPENDENCIES = (ParentNodeProvider, PositionProvider) + + def __init__(self, python_module: cst.Module): + # fmt: off + self.python_module: cst.Module = python_module # original cst.Module being visited + self.classes: Dict[str, cst.ClassDef] = {} # mapping from class names to Nodes (it will be ordered by default!!) + self.imports = [] # stores all import statements + self.functions: Dict[str, cst.FunctionDef] = {} # mapping of global scope function names to Nodes + self.object_dependency_mapping = defaultdict(set) # immediate function/assignment dependency mapping (i.e. 
dependencies immediately in the function/assignment definition) + self.assignments: Dict[str, cst.SimpleStatementLine] = {} # mapping of global assignments names to Nodes + self.current_function = None # this keeps track of the current module-scope function + self.current_assignment = None # this keeps track of the current module-scope assignment + # this keeps track of objects imported from modeling files (`from .configuration import Config`) -> `Config` should not be a dependency + self.objects_imported_from_modeling = set() + # regex pattern joining every possible file type + self.match_patterns = "|".join(ALL_FILE_TYPES) + # fmt: on + + def visit_ImportFrom(self, node): + """This keeps track of objects imported from neighbor modeling files (e.g. in `modeling_xxx.py, we have + `from .configuration_xxx import Config`, then `Config` should be recorded as it is not a dependency that needs + to be added (because it will be part of the imports)""" + import_module = self.python_module.code_for_node(node.module) + import_statement = "." * len(node.relative) + import_module + if re.search(rf"^\.({self.match_patterns})_.*", import_statement): + for imported_object in node.names: + # If an alias is present, we record it and not the original name + if imported_object.evaluated_alias is not None: + self.objects_imported_from_modeling.add(imported_object.evaluated_alias) + else: + self.objects_imported_from_modeling.add(imported_object.evaluated_name) + + def visit_SimpleStatementLine(self, node): + """ + Global Assigns like `GEMMA_INPUT_DOCSTRING = 'THIS IS THE INPUT'` and all import statements + are extracted and saved in their corresponding dict. They are then used when updating dependency mappings. + """ + parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node) + simple_top_level_assign_structure = m.SimpleStatementLine( + body=[m.Assign(targets=[m.AssignTarget(target=m.Name())])] + ) + if m.matches(parent_node, m.Module()): + if m.matches(node, simple_top_level_assign_structure): + left_hand_side = node.body[0].targets[0].target.value + self.current_assignment = left_hand_side + self.assignments[left_hand_side] = node + elif m.matches(node, m.SimpleStatementLine(body=[m.Import() | m.ImportFrom()])): + self.imports.append(node) + + def leave_SimpleStatementLine(self, node): + # No need to check for the parent here -> everytime we exit one, it should be None anyway independently of where the + # SimpleStatement is located + self.current_assignment = None + + def visit_FunctionDef(self, node): + parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node) + if m.matches(parent_node, m.Module()): + self.current_function = node.name.value + self.functions[node.name.value] = node + + def leave_FunctionDef(self, node): + parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node) + if m.matches(parent_node, m.Module()): + self.current_function = None + + def visit_If(self, node): + for stmt in node.body.body: + if m.matches(stmt, m.SimpleStatementLine(body=[m.ImportFrom() | m.Import()])): + self.imports.append(node) + + def visit_ClassDef(self, node: ClassDef) -> None: + """Record class nodes to create their dependencies at the end.""" + self.classes[node.name.value] = node + + def visit_Name(self, node: cst.Call): + """This is used to create a mapping from module-scope functions and assignments to objects used inside them.""" + if self.current_function is not None: + self.object_dependency_mapping[self.current_function].add(node.value) + if self.current_assignment 
is not None: + self.object_dependency_mapping[self.current_assignment].add(node.value) + + def leave_Module(self, node): + """When leaving the module, we store the position of each global scoped node to allow sorting the dependencies + based on their position in the code later. We use the PositionProvider metadata wrapper for this. + We also make sure to update `self.object_dependency_mapping` so that it contains only names recorded in + `self.global_nodes`. + """ + # assign all nodes + self.global_nodes = {**self.assignments, **self.classes, **self.functions} + # now sort the class dependency_mapping based on the position of the nodes + self.start_lines = {} + for id, node in self.global_nodes.items(): + self.start_lines[id] = self.get_metadata(cst.metadata.PositionProvider, node).start.line + + # Since we added every Name as part of `self.object_dependency_mapping`, we now remove those that + # are not part of the recorded objects (i.e. built-in variables, imports, etc) + global_objects = set(self.global_nodes.keys()) + for object_name, dependencies in self.object_dependency_mapping.items(): + self.object_dependency_mapping[object_name] = {dep for dep in dependencies if dep in global_objects} + + def _compute_recursive_object_dependencies(self) -> dict[str, set]: + """Based on immediate dependency mapping, create the recursive dependency mapping. For example, given the + following file: + ``` + def foo(): + pass + + def bar(): + foo() + + def test(): + bar() + ``` + this visitor can only record immediate dependencies, i.e. it will record the following + `self.object_dependency_mapping = {"test": {"bar"}, "bar": {"foo}}`. This function is used to create + the recursive mapping, i.e. `recursive_dependencies = {"test": {"bar", "foo"}, "bar": {"foo}}`. + """ + recursive_dependencies = {} + for object_name in self.object_dependency_mapping.keys(): + all_dependencies = find_all_dependencies(self.object_dependency_mapping, start_entity=object_name) + recursive_dependencies[object_name] = all_dependencies + return recursive_dependencies + + def augment_dependencies(self, dependencies: set[str]) -> set[str]: + """For a set of `dependencies`, augment them by adding all potential dependencies of the **functions** and + **assignments** present in the `dependencies`. + """ + new_dependencies = dependencies.copy() + # Go through the set of dependencies + for dep in tuple(dependencies): + if dep in self.object_recursive_dependency_mapping.keys(): + new_dependencies.update(self.object_recursive_dependency_mapping[dep]) + return new_dependencies + + def compute_class_dependencies(self): + """For each visited class, find its dependencies based on visiting the current file + potential merged dependencies.""" + self.class_dependency_mapping = {} + for class_name, class_node in self.classes.items(): + dependencies = dependencies_for_class_node(class_node, set(self.global_nodes.keys())) + # Correctly augment class dependencies with all needed objects + self.class_dependency_mapping[class_name] = self.augment_dependencies(dependencies) + + @abstractmethod + def compute_relative_order(self, missing_dependencies: set) -> dict[str, int]: + raise NotImplementedError + + +class ModelFileMapper(ModuleMapper): + """A mapper designed to parse modeling files (like `modeling_llama.py`). When encountering such a file + in the `modular_xxx.py` file, we need to correctly visit it and merge the dependencies of the modular and current file. 
+ For this reason, this class should only be instantiated from the class method `visit_and_merge_dependencies`, which takes + care of correctly merging dependencies, then finalizes all dependency graph computations. + Note that we only merge functions and assignments here, as classes will be treated later on as they may be modified. + For example, if you redefine `apply_rotary_pos_emb()` in the modular, the new node should be used in the dependencies + of the modeling files as well. + """ + + def __init__(self, python_module: cst.Module): + super().__init__(python_module) + + def compute_relative_order(self, missing_dependencies: set[str]) -> dict[str, int]: + """Compute in which relative order the `missing_dependencies` should appear when the nodes are added to the final file that + will be created based on the modular. + """ + relative_order = {} + idx = 0 + classes = sorted( + [dep for dep in tuple(missing_dependencies) if dep in self.classes], key=lambda x: self.start_lines[x] + ) + # This is because for merged dependencies, we only have relative order in the other visited file, so we need + # to track dependency order relative to a given class + if len(classes) > 0 and not hasattr(self, "class_dependency_mapping"): + raise ValueError("Cannot correctly find the relative order of the dependencies.") + + remaining_dependencies = missing_dependencies.copy() + + # Start by tracking relative order class by class + for class_name in classes: + class_dependencies = tuple(self.class_dependency_mapping[class_name] & remaining_dependencies) + original_dependencies = [] + merged_dependencies = [] + # We need to differentiate between nodes that were already present (we can get relative order globally) and + # nodes that were merged (we can get relative order only relative to the class the dependencies relate to) + for class_dep in class_dependencies: + if class_dep in self.start_lines: + original_dependencies.append(class_dep) + else: + merged_dependencies.append(class_dep) + # Sort both list according to the order in their respective file + original_dependencies = sorted(original_dependencies, key=lambda x: self.start_lines[x]) + merged_dependencies = sorted(merged_dependencies, key=lambda x: self.modular_file_start_lines[x]) + + # Add all original node first, then merged ones + for dep in original_dependencies + merged_dependencies: + remaining_dependencies.remove(dep) + relative_order[dep] = idx + idx += 1 + # Add the class itself + remaining_dependencies.remove(class_name) + relative_order[class_name] = idx + idx += 1 + + # Now add what still remains + remaining_dependencies = tuple(remaining_dependencies) + original_dependencies = [] + merged_dependencies = [] + for dep in remaining_dependencies: + if dep in self.modular_file_start_lines: + merged_dependencies.append(dep) + else: + original_dependencies.append(dep) + # Sort both list according to the order in their respective file + original_dependencies = sorted(original_dependencies, key=lambda x: self.start_lines[x]) + merged_dependencies = sorted(merged_dependencies, key=lambda x: self.modular_file_start_lines[x]) + + # Add all original node first, then merged ones + for dep in original_dependencies + merged_dependencies: + relative_order[dep] = idx + idx += 1 + + return relative_order + + def _merge_functions(self, functions: dict[str, cst.CSTNode], object_mapping: dict[str, set]): + """Update the global nodes and function dependency mapping with those from the modular file. 
+ + Merging rule: if any function with the same name was redefined in the modular, use it and its dependencies + instead of the original ones (this may mean to add new functions as well, if any redefined function uses a new one). + """ + # Add/overwrite all needed function nodes and dependencies + self.functions.update(functions) + self.object_dependency_mapping.update( + {obj: dep for obj, dep in object_mapping.items() if obj in functions.keys()} + ) + + def _merge_assignments(self, assignments: dict[str, cst.CSTNode], object_mapping: dict[str, set]): + """Update the global nodes with the assignment from the modular file. + + Merging rule: if any assignment with the same name was redefined in the modular, we use it and its dependencies ONLY if it is + in `ASSIGNMENTS_TO_KEEP`. Otherwise, we use the original value and dependencies. This rule was chosen to avoid having to rewrite the + big docstrings. + """ + for assignment, node in assignments.items(): + if assignment in ASSIGNMENTS_TO_KEEP or assignment not in self.assignments: + self.assignments[assignment] = node + if assignment in object_mapping: + self.object_dependency_mapping[assignment] = object_mapping[assignment] + + def _merge_classes(self, classes: dict[str, cst.CSTNode]): + """Update the global nodes with the new classes from the modular (i.e. classes which do not exist in current file, and + are not imported). We do NOT update any dependency mapping here. This is because we only need the names of newly defined + classes in the modular to be discoverable when computing dependencies for new nodes later on. For this reason, we + do not add the new classes to `self.classes`, but only to `global_nodes`. + """ + # Add/overwrite all needed function nodes and dependencies + self.global_nodes.update( + { + name: node + for name, node in classes.items() + if name not in self.classes and name not in self.objects_imported_from_modeling + } + ) + + def merge_modular_dependencies(self, classes, functions, assignments, object_mapping, start_lines): + """Merge classes, functions and assignments from the modular definitions into the current module file, + then record the relative order of all nodes. + Note: This function takes care of updating `global_nodes` and `object_recursive_dependency_mapping` as well after the + merge with other files dependencies. + """ + self._merge_functions(functions, object_mapping) + self._merge_assignments(assignments, object_mapping) + self._merge_classes(classes) + self.modular_file_start_lines = start_lines + + # Correctly re-set the global nodes at this point + self.global_nodes.update(self.functions) + self.global_nodes.update(self.assignments) + # Create the global mapping of recursive dependencies for functions and assignments + self.object_recursive_dependency_mapping = self._compute_recursive_object_dependencies() + + @classmethod + def visit_and_merge_dependencies( + cls, module: cst.Module, classes, functions, assignments, object_mapping, start_lines + ) -> "ModelFileMapper": + wrapper = MetadataWrapper(module) + mapper = cls(module) + wrapper.visit(mapper) + # Merge dependencies + mapper.merge_modular_dependencies(classes, functions, assignments, object_mapping, start_lines) + # Create the class dependencies graph + mapper.compute_class_dependencies() + return mapper + + +def replace_class_node(mapper: ModelFileMapper, class_node: cst.ClassDef, renamed_super_class: str): + """ + Replace a class node which inherits from another modeling class. 
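The `_merge_assignments` rule above can be sketched on plain dicts. `ASSIGNMENTS_TO_KEEP` does exist in the converter, but its content below is only illustrative, as are the assignment names and values:

```python
# Hedged sketch of the assignment merging rule: a modular redefinition only wins
# if the name is in ASSIGNMENTS_TO_KEEP or absent from the modeling file.
ASSIGNMENTS_TO_KEEP = {"_CHECKPOINT_FOR_DOC"}  # illustrative content

modeling_assignments = {"_CHECKPOINT_FOR_DOC": "old-checkpoint", "MODEL_START_DOCSTRING": "original docstring"}
modular_assignments = {"_CHECKPOINT_FOR_DOC": "new-checkpoint", "MODEL_START_DOCSTRING": "redefined docstring"}

merged = dict(modeling_assignments)
for name, value in modular_assignments.items():
    if name in ASSIGNMENTS_TO_KEEP or name not in merged:
        merged[name] = value

print(merged["_CHECKPOINT_FOR_DOC"])    # new-checkpoint  -> modular redefinition kept
print(merged["MODEL_START_DOCSTRING"])  # original docstring -> original kept, avoids rewriting big docstrings
```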
This function works in the following way: + - start from the base class node of the inherited class (a cst.Node) + - replace all methods of the base node with the methods defined in the child class + - append all new methods defined in the child class + - replace all calls to super() with the unravelled code | ```python | | ```python | class GemmaModel(LlamaModel): | | class GemmaModel(nn.Module): @@ -547,14 +856,15 @@ def replace_call_to_super( | self.post_init() | ``` """ - original_node = class_finder.classes[class_name] + all_bases = [k.value.value for k in class_node.bases] + + original_node = mapper.classes[renamed_super_class] original_methods = { - f.name.value if hasattr(f, "name") else class_finder.python_module.code_for_node(f): f + f.name.value if hasattr(f, "name") else mapper.python_module.code_for_node(f): f for f in original_node.body.body } updated_methods = { - f.name.value if hasattr(f, "name") else class_finder.python_module.code_for_node(f): f - for f in updated_node.body.body + f.name.value if hasattr(f, "name") else mapper.python_module.code_for_node(f): f for f in class_node.body.body } end_meth = [] @@ -562,7 +872,7 @@ def replace_call_to_super( docstring_node = [] # Iterate directly from node.body as there can be property/setters with same names which are overwritten when we use a dict for func in original_node.body.body: - name = func.name.value if hasattr(func, "name") else class_finder.python_module.code_for_node(func) + name = func.name.value if hasattr(func, "name") else mapper.python_module.code_for_node(func) if m.matches(func, m.FunctionDef()) and name in updated_methods and updated_methods[name] is not None: new_params = updated_methods[name].params # Replace the method in the replacement class, preserving decorators @@ -573,19 +883,23 @@ def replace_call_to_super( new_params = new_params.with_changes( params=list(parent_params.values()), star_kwarg=func.params.star_kwarg ) + # Keep decorators in `modular_xxx.py` if any, else original decorators + new_decorators = ( + updated_methods[name].decorators if len(updated_methods[name].decorators) > 0 else func.decorators + ) if not re.match( r"\ndef .*\(.*\):\n raise.*Error\(.*", - class_finder.python_module.code_for_node(updated_methods[name]), + mapper.python_module.code_for_node(updated_methods[name]), ): - func = func.with_changes(body=updated_methods[name].body, params=new_params) + func = func.with_changes(body=updated_methods[name].body, params=new_params, decorators=new_decorators) else: continue if m.matches(func, m.SimpleStatementLine(body=[m.Assign()])): - target = class_finder.python_module.code_for_node(func.body[0].targets[0]) + target = mapper.python_module.code_for_node(func.body[0].targets[0]) assign_targets[target] = func elif m.matches(func, m.SimpleStatementLine(body=[m.AnnAssign()])): - target = class_finder.python_module.code_for_node(func.body[0].target) + target = mapper.python_module.code_for_node(func.body[0].target) assign_targets[target] = func elif m.matches(func, DOCSTRING_NODE): docstring_node = [func] @@ -593,8 +907,8 @@ def replace_call_to_super( end_meth.append(func) # Port new methods that are defined only in modular-file and append at the end - for func in updated_node.body.body: - name = func.name.value if hasattr(func, "name") else class_finder.python_module.code_for_node(func) + for func in class_node.body.body: + name = func.name.value if hasattr(func, "name") else mapper.python_module.code_for_node(func) if m.matches(func, DOCSTRING_NODE): # This processes the docstring 
of the class! # Extract the original docstring updated_docstring = func.body[0].value.value @@ -608,22 +922,28 @@ def replace_call_to_super( end_meth.append(func) if m.matches(func, m.SimpleStatementLine(body=[m.Assign()])): # TODO we only use single assign might cause issues - target = class_finder.python_module.code_for_node(func.body[0].targets[0]) + target = mapper.python_module.code_for_node(func.body[0].targets[0]) assign_targets[target] = func if m.matches(func, m.SimpleStatementLine(body=[m.AnnAssign()])): - target = class_finder.python_module.code_for_node(func.body[0].target) + target = mapper.python_module.code_for_node(func.body[0].target) assign_targets[target] = func end_meth = docstring_node + list(assign_targets.values()) + end_meth + # Replace the calls to `super()` with the unrolled code result_node = original_node.with_changes(body=cst.IndentedBlock(body=end_meth)) temp_module = cst.Module(body=[result_node]) new_module = MetadataWrapper(temp_module) new_replacement_class = new_module.visit( - SuperTransformer(temp_module, original_methods, updated_methods, class_name, all_bases) + SuperTransformer(temp_module, original_methods, updated_methods, all_bases) ) new_replacement_body = new_replacement_class.body[0].body # get the indented block - return original_node.with_changes(body=new_replacement_body) + # Use decorators redefined in `modular_xxx.py` if any + new_decorators = class_node.decorators if len(class_node.decorators) > 0 else original_node.decorators + # Always use the new name of the class (in case we use e.g. `ColPaliForRetrieval` inheriting from `PaliGemmaForConditionalGeneration`) + name = class_node.name + + return original_node.with_changes(body=new_replacement_body, decorators=new_decorators, name=name) TYPE_TO_FILE_TYPE = { @@ -632,498 +952,483 @@ def replace_call_to_super( "Processor": "processing", "ImageProcessor": "image_processing", "FeatureExtractor": "feature_extractor", + "ProcessorKwargs": "processing", + "ImagesKwargs": "processing", + "TextKwargs": "processing", } -def get_new_part(class_name, base_class): +def find_file_type(class_name: str) -> str: + """Based on a class name, find the file type corresponding to the class. + If the class name is `LlamaConfig` it will return `configuration`. + The list of suffixes is in `TYPE_TO_FILE_TYPE`. If there are no match, we match by default to `modeling` """ - When `MyClassNameAttention` inherits from `MistralAttention`, we need - to process the name to properly find dependencies. - - Here we take what is the same (Attention) and what is different - when finding the dependencies. - """ - common_suffix_len = 0 - for i in range(1, min(len(class_name), len(base_class)) + 1): - if class_name[-i] == base_class[-i]: - common_suffix_len += 1 - else: - break - - if common_suffix_len > 0: - new_part = class_name[:-common_suffix_len] + match_pattern = "|".join(TYPE_TO_FILE_TYPE.keys()) + match = re.search(rf"({match_pattern})$", class_name) + if match: + file_type = TYPE_TO_FILE_TYPE[match.group(1)] else: - new_part = class_name + file_type = "modeling" + return file_type - # Convert the remaining new part to snake_case - snake_case = re.sub(r"(? 
0: + new_node = node.with_changes(body=[import_node.with_changes(names=names_to_keep)]) + imports_to_keep.append(new_node) - def foobar(): - bar() - foo2() - class MyLayer(SomeOtherModelLayer): - def forward(...): - foobar() - ``` - and the `dependency_mapping` created when visiting the `modular_xxx.py` file, we get: - ``` - dependency_mapping = {'bar': {'foo1'}, 'foobar': {'bar', 'foo2'}} - find_all_dependencies('foobar', dependency_mapping) - >>> [('bar', 'foobar'), ('foo2', 'foobar'), ('foo1', 'bar')] - ``` - That is, all the functions needed (and their immediate parent) so that the function to be added in MyLayer (`foobar`) can - work correctly. +def get_needed_imports(body: dict[str, dict], all_imports: list[cst.CSTNode]) -> list[cst.CSTNode]: + """Get all the imports needed in the `body`, from the list of `all_imports`. + `body` is a dict with the following structure `{str: {"insert_idx": int, "node": cst.CSTNode}}`. + Note: we need to use `isinstance` on scope assignements, m.matches apparently does not work here yet! """ - all_dependencies = deque(dependency_mapping[function]) - all_dependencies_with_parent = [(dep, function) for dep in dependency_mapping[function]] - checked_dependencies = set(function) - while len(all_dependencies) > 0: - # Pick element to visit - parent = all_dependencies.popleft() - if parent not in checked_dependencies: - # Update dependencies - all_dependencies.extend(dependency_mapping[parent]) - all_dependencies_with_parent += [(dependency, parent) for dependency in dependency_mapping[parent]] - # add visited node to the list - checked_dependencies.add(parent) - - # no child can ever appear before its parent thanks to the queue (needed to add them at the correct location in the body later) - return all_dependencies_with_parent - - -class PostModularConverterCleaner(CSTTransformer): - """Allow simple cleaning after conversion. Remove top-level functions/classes without any calls (they may arise due - to dependency mapping, even if code parts with those functions/classes were overwritten)""" - - METADATA_DEPENDENCIES = (ParentNodeProvider,) - - def __init__(self, added_dependencies: set): - super().__init__() - self.top_level_functions_or_classes = {} - self.all_used_functions_or_classes = set() - self.added_dependencies = added_dependencies - - def visit_FunctionDef(self, node): - parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node) - if m.matches(parent_node, m.Module()): - self.top_level_functions_or_classes[node.name.value] = node - - def visit_ClassDef(self, node): - parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node) - if m.matches(parent_node, m.Module()): - self.top_level_functions_or_classes[node.name.value] = node - - def visit_Name(self, node: cst.Name): - """This is used to find any mention of a top-level function or class except its own definition. - It will contain other names as well, but those will not be used. This is the most general way to do it - since mentions may appear in a lot of different contexts (apart from simple Call to the function/class). - e.g. Attention classes are only mentionned by their name in a dict assignment. 
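The new `find_file_type` helper introduced above replaces the old suffix arithmetic with a single suffix match. A quick usage sketch; the `TYPE_TO_FILE_TYPE` dict below is a trimmed-down stand-in for the real mapping:

```python
import re

TYPE_TO_FILE_TYPE = {"Config": "configuration", "Processor": "processing"}  # subset for illustration

def find_file_type(class_name: str) -> str:
    # Match the class-name suffix against the known file types, default to "modeling".
    match = re.search(rf"({'|'.join(TYPE_TO_FILE_TYPE.keys())})$", class_name)
    return TYPE_TO_FILE_TYPE[match.group(1)] if match else "modeling"

print(find_file_type("LlamaConfig"))     # configuration
print(find_file_type("LlamaModel"))      # modeling
print(find_file_type("LlavaProcessor"))  # processing
```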
- """ - parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node) - - if not ( - (m.matches(parent_node, m.ClassDef()) and parent_node.name.value == node.value) - or (m.matches(parent_node, m.FunctionDef()) and parent_node.name.value == node.value) - ): - self.all_used_functions_or_classes.add(node.value) - - def leave_Module(self, original_node: cst.Module, node): - # Find any class/function that was mistakenly added as part of the dependencies and remove it - unused = self.added_dependencies - self.all_used_functions_or_classes - nodes_to_remove = [ - self.top_level_functions_or_classes[name] for name in unused if name in self.top_level_functions_or_classes - ] - new_body = [node_ for node_ in original_node.body if node_ not in nodes_to_remove] - # Return a new module with the updated body - return node.with_changes(body=new_body) + new_body = [k[1]["node"] for k in sorted(body.items(), key=lambda x: x[1]["insert_idx"])] + wrapper = MetadataWrapper(cst.Module(body=all_imports + new_body)) + scopes = set(wrapper.resolve(ScopeProvider).values()) + unused_imports = set() + import_ref_count = {} + for scope in scopes: + for assignment in scope.assignments: + node = assignment.node + if isinstance(assignment, cst.metadata.Assignment) and isinstance(node, (cst.Import, cst.ImportFrom)): + ref_count = len(assignment.references) + name = assignment.name + # Similar imports may be redefined, and only used between their 1st and 2nd definition + # so if we already have a ref count > 0, the imports is actually used + if (ref_count == 0 and import_ref_count.get(name, -1) <= 0) or name in body.keys(): + unused_imports.add(name) + import_ref_count[name] = ref_count + + imports_to_keep = [] + for node in all_imports: + if m.matches(node, m.If()): # handle safe imports + new_statements = [] + for stmt_node in node.body.body: + append_new_import_node(stmt_node, unused_imports, new_statements) + if len(new_statements) > 0: + new_node = node.with_changes(body=node.body.with_changes(body=new_statements)) + imports_to_keep.append(new_node) + else: + append_new_import_node(node, unused_imports, imports_to_keep) + + protected_import_nodes = [node for node in imports_to_keep if m.matches(node, m.If())] + usual_import_nodes = [node for node in imports_to_keep if not m.matches(node, m.If())] + # If the same import is both protected and unprotected, only keep the protected one + for protected_node in protected_import_nodes: + for stmt_node in protected_node.body.body: + usual_import_nodes = [node for node in usual_import_nodes if node.body[0] != stmt_node.body[0]] + + # Protected imports always appear at the end of all imports + return usual_import_nodes + protected_import_nodes + + +def split_all_assignment(node: cst.CSTNode) -> dict[str, cst.CSTNode]: + """Split the `__all__` assignment found in the modular between each corresponding files.""" + all_all_per_file = {} + assign_node = node.body[0] + if isinstance(assign_node.value, cst.List): + # Extract the elements from the list + all_all_to_add = defaultdict(list) + for element in assign_node.value.elements: + if isinstance(element.value, cst.SimpleString): + # Remove quotes and add the string to the elements list + class_name = element.value.value + file = find_file_type(element.value.evaluated_value) + all_all_to_add[file] += [class_name] + for file, new_alls in all_all_to_add.items(): + new_node = assign_node.with_changes( + value=cst.List(elements=[cst.Element(value=cst.SimpleString(value=k)) for k in new_alls]) + ) + all_all_per_file[file] = 
node.with_changes(body=[new_node]) + return all_all_per_file -class ModularConverterTransformer(CSTTransformer): - METADATA_DEPENDENCIES = (ParentNodeProvider, ScopeProvider, PositionProvider) +class ModularFileMapper(ModuleMapper): + """This is a Mapper to visit a modular file (like `modular_llama.py`). It visits the whole file, recording dependency, + then visits all imported modeling files (like `modeling_llama.py`), and manages their mutual dependencies. + Calling the method `create_modules()` after visit will create all modules based on this modular file. + """ def __init__(self, python_module, new_name, given_old_name=None, given_new_name=None): - super().__init__() - self.model_name = ( - new_name # name of the model being defined. Should be in the format of `llama` or `layout_xlm` our `phi3` - ) + super().__init__(python_module) + # fmt: off + self.model_name = new_name # name of the model being defined. Should be in the format of `llama` or `layout_xlm` or `phi3` self.given_old_name = given_old_name self.given_new_name = given_new_name - # fmt: off - self.python_module = python_module # we store the original module to use `code_for_node` - self.transformers_imports = {} # maps the imports name like "from transformers.models.xxx" to the parsed AST module - self.imported_mapping = {} # stores the name of the imported classes, with their source {"LlamaModel":"transformers.model.llama.modeling_llama"} - self.visited_module = {} # modules visited like "transformers.models.llama.modeling_llama" - self.inserted_deps = [] # nodes inserted via super dependency - self.all_imports = [] # just stores all of the imports - self.all_safe_imports = [] # stores the import under simple statements - self.global_scope_index = 0 + + self.model_specific_imported_objects: Dict[str, str] = {} # e.g. {"LlamaModel": "transformers.models.llama.modeling_llama"} + self.model_specific_modules: Dict[str, cst.Module] = {} # e.g. {"transformers.models.llama.modeling_llama": cst.Module} + + self.all_all_to_add = {} # fmt: on - self.files = { # mapping for different component bodies - "modeling": {}, - "configuration": {}, - "tokenization": {}, - "processing": {}, - "image_processing": {}, - "feature_extractor": {}, - } - self.match_patterns = "|".join(self.files.keys()) - self.all_definitions = {} - self.class_to_file_type = {} - self.current_class = None # keep track of current top-level class during visit - self.current_top_level_function = None # keep track of current top-level function during visit - # Mapping from top-level functions to classes using them - self.function_call_class_mapping = defaultdict(lambda: set()) - # Mapping from top-level functions to other top-level functions dependencies - self.function_call_dependency_mapping = defaultdict(lambda: set()) - self.added_dependencies = set() def visit_ImportFrom(self, node: cst.ImportFrom) -> None: - """When visiting imports from `transformers.models.xxx` we need to: - 1. Get the original source code - 2. Parse it into an AST Tree - 3. Add this import to `self.transformers_imports` as visited to not parse it twice + """When visiting imports from modeling files (i.e. `transformers.models.xxx`) we get the code, parse it, + and save it in `self.model_specific_modules` to later visit. The imported objects are saved in `self.model_specific_imported_objects`. """ - import_statement = self.python_module.code_for_node(node.module) + import_module = self.python_module.code_for_node(node.module) + import_statement = "." 
* len(node.relative) + import_module + if any(import_to_skip in import_statement for import_to_skip in IMPORTS_TO_SKIP_IN_MODULAR): + return if m.matches(node.module, m.Attribute()): for imported_ in node.names: - _import = re.search(rf"(transformers\.models\..|..)*\.({self.match_patterns})_.*", import_statement) + _import = re.search( + rf"(?:transformers\.models\.)|(?:\.\.)\w+\.({self.match_patterns})_.*", import_statement + ) if _import: - source = _import.groups()[0] + source = _import.group(1) if source == "modeling" and "Config" in self.python_module.code_for_node(imported_): raise ValueError( f"You are importing {self.python_module.code_for_node(imported_)} from the modeling file. Import from the `configuration_xxxx.py` file instead" ) - if import_statement not in self.transformers_imports: - if "models" not in import_statement: - import_statement = "models." + import_statement - if "transformers" not in import_statement: - import_statement = "transformers." + import_statement - source_code = get_module_source_from_name(import_statement) + if import_module not in self.model_specific_modules: + if "models" not in import_module: + import_module = "models." + import_module + if "transformers" not in import_module: + import_module = "transformers." + import_module + source_code = get_module_source_from_name(import_module) tree = cst.parse_module(source_code) - self.transformers_imports[import_statement] = tree - imported_class = self.python_module.code_for_node(imported_.name) - self.imported_mapping[imported_class] = import_statement + self.model_specific_modules[import_module] = tree + imported_object = self.python_module.code_for_node(imported_.name) + self.model_specific_imported_objects[imported_object] = import_module if m.matches(node.module, m.Name()): - if "transformers" == import_statement: + if "transformers" == import_module: raise ValueError( - f"You are importing from {import_statement} directly using global imports. Import from the correct local path" + f"You are importing from {import_module} directly using global imports. Import from the correct local path" ) - def leave_SimpleStatementLine(self, original_node, updated_node): - parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node) + def visit_SimpleStatementLine(self, node): + """If we visit an import statement not previously visited, record it. If we visit a module-scope assignment, + simply record it or, if it is `__all__`, split it between files where we should dispatch it. 
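Conceptually, the `__all__` split routes each exported name to the file type implied by its suffix. A plain-Python sketch of that routing (the class names are illustrative, and the simplified suffix rule stands in for `find_file_type`; the real `split_all_assignment` builds cst nodes instead of lists):

```python
all_entries = ["MyNewModelConfig", "MyNewModelModel", "MyNewModelForCausalLM"]

def find_file_type(class_name):  # simplified stand-in for the real helper
    return "configuration" if class_name.endswith("Config") else "modeling"

per_file = {}
for name in all_entries:
    per_file.setdefault(find_file_type(name), []).append(name)

print(per_file)
# {'configuration': ['MyNewModelConfig'], 'modeling': ['MyNewModelModel', 'MyNewModelForCausalLM']}
```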
+ """ + parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node) + simple_top_level_assign_structure = m.SimpleStatementLine( + body=[m.Assign(targets=[m.AssignTarget(target=m.Name())])] + ) if m.matches(parent_node, m.Module()): - if m.matches(updated_node, m.SimpleStatementLine(body=[m.Import()])): - if updated_node not in self.all_imports: - self.all_imports.append(updated_node) - return updated_node - elif m.matches(updated_node, m.SimpleStatementLine(body=[m.ImportFrom()])): - full_statement = self.python_module.code_for_node(updated_node.body[0].module) - if re.search( - rf"(transformers\.models\..|..)*\.({self.match_patterns})_.*", full_statement - ): # OR MATCH ..llama.modeling_llama - return cst.RemoveFromParent() - if updated_node not in self.all_imports: - self.all_imports.append(updated_node) - return updated_node - elif m.matches(original_node, m.SimpleStatementLine(body=[m.Assign()])): - if original_node.body[0].targets[0].target.value in ASSIGNMENTS_TO_KEEP.keys(): - file_ = ASSIGNMENTS_TO_KEEP[original_node.body[0].targets[0].target.value] - self.files[file_][original_node.body[0].targets[0].target.value] = { - "node": original_node, - "insert_idx": self.global_scope_index, - } - self.global_scope_index += 100 - return updated_node - - def visit_ClassDef(self, node: cst.ClassDef): - """Used to keep track of current class""" - self.current_class = node.name.value + if m.matches(node, m.SimpleStatementLine(body=[m.Import()])): + self.imports.append(node) + elif m.matches(node, m.SimpleStatementLine(body=[m.ImportFrom()])): + import_module = self.python_module.code_for_node(node.body[0].module) + import_statement = "." * len(node.body[0].relative) + import_module + if not ( + re.search(rf"(?:transformers\.models\.)|(?:\.\.)\w+\.({self.match_patterns})_.*", import_statement) + and not any(import_to_skip in import_statement for import_to_skip in IMPORTS_TO_SKIP_IN_MODULAR) + ): + self.imports.append(node) + elif m.matches(node, simple_top_level_assign_structure): + assigned_variable = node.body[0].targets[0].target.value + # __all__ is treated differently and not added to general assignments + if assigned_variable == "__all__": + self.all_all_to_add = split_all_assignment(node) + else: + self.assignments[assigned_variable] = node - def leave_ClassDef(self, original_node, updated_node): + def leave_Module(self, node): + """When we leave the modular file, we do the following in order: + 1. compute the nested (recursive) function and assignment dependencies + 2. for each modeling file found in the imports, rename it with the new model name, visit it, and update + its dependency graph with the new function and assignment definitions found in the modular + 3. update the modular dependency graph with the imported functions and assignments (found when visiting the matching files) """ - 1. Filter the `base` classes of this class - If they are from `transformers.models.xx` then: - - take the AST tree of the module it comes from and parse it with a `ClassFinder`. - - rename all every instance of `old_name` (llama) to `new_name` (gemma) - 2. We insert the modules which the inherited base depends on. This has to be done in - the order of the dependencies. If on is already in the new_body (because it's defined in the diff file) - then we remove it from the new body to add it again in the correct order. - 3. Replace the calls to `super().xxxx` merging parent code + # Takes care of finalizing our visit + super().leave_Module(node) + + # 1. 
compute the nested (recursive) function and assignment dependencies + self.object_recursive_dependency_mapping = self._compute_recursive_object_dependencies() + + # 2. for each modeling file found in the imports, rename it with the new model name, visit it, and update dependencies + self.visited_modules = {} + self.renamers = {} + for file, module in self.model_specific_modules.items(): + file_model_name = re.search(r"models\.\w*?\.\w*?_(\S*)", file).groups()[0] + renamer = ReplaceNameTransformer( + file_model_name, self.model_name, self.given_old_name, self.given_new_name + ) + renamed_module = module.visit(renamer) + self.visited_modules[file] = ModelFileMapper.visit_and_merge_dependencies( + renamed_module, + self.classes, + self.functions, + self.assignments, + self.object_dependency_mapping, + self.start_lines, + ) + # We record it so that we can rename classes later the exact same way + self.renamers[file] = renamer + + # 3. in turn, we need to add the imported functions/assignments to the dependencies of the modular mapper, using the + # definitions found in the visited files + self.merge_model_specific_imports(self.visited_modules) + + # We need to keep track of which objects were imported directly into which modeling file to not add them wrongly later + # Note that we may visit several of the same file types, thus we save them per file type, not file + self.imported_objects_per_file = defaultdict(set) + for file, mapper in self.visited_modules.items(): + file_type = re.search(rf"^transformers\.models\.\w+\.({self.match_patterns})_.*", file).group(1) + self.imported_objects_per_file[file_type].update(mapper.objects_imported_from_modeling) + + def merge_model_specific_imports(self, visited_modules): + """Merge the functions and assignments imported from the modeling files to the modular nodes and dependency graph, + based on the visited files.""" + self.start_lines_file_mapping = {} + self.added_objects_file_mapping = {} + for object_name, file in self.model_specific_imported_objects.items(): + visited_module = visited_modules[file] + self.start_lines_file_mapping[file] = visited_module.start_lines + # Add functions and their dependencies + if object_name in visited_module.functions and object_name not in self.functions: + self.functions[object_name] = visited_module.functions[object_name] + self.added_objects_file_mapping[object_name] = file + dependencies = visited_module.object_recursive_dependency_mapping.get(object_name, None) + if dependencies is not None: + self.object_recursive_dependency_mapping[object_name] = dependencies + for dep in dependencies: + if dep not in self.global_nodes: + self.added_objects_file_mapping[dep] = file + self.functions[dep] = visited_module.global_nodes[dep] + + # Add assignments and their dependencies + elif object_name in visited_module.assignments and object_name not in self.assignments: + self.assignments[object_name] = visited_module.assignments[object_name] + self.added_objects_file_mapping[object_name] = file + dependencies = visited_module.object_recursive_dependency_mapping.get(object_name, None) + if dependencies is not None: + self.object_recursive_dependency_mapping[object_name] = dependencies + for dep in dependencies: + if dep not in self.global_nodes: + self.added_objects_file_mapping[dep] = file + self.assignments[dep] = visited_module.global_nodes[dep] + + # Do not forget to re-assign all nodes after the merge + self.global_nodes = {**self.assignments, **self.classes, **self.functions} + + def compute_relative_order(self, 
missing_dependencies: set) -> dict[str, int]: + """Compute in which relative order the `missing_dependencies` should appear when the nodes are added to the final file that + will be created based on the modular. """ - class_name = original_node.name.value - bases = [k.value.value for k in original_node.bases if k.value.value in self.imported_mapping] - all_bases = [k.value.value for k in original_node.bases] - self.global_scope_index += 100 - for super_class in bases: - if super_class not in self.imported_mapping: - raise ImportError( - f"{super_class} was not imported using `from transformers.models.xxxxx.modeling_xxxx import {super_class}" - ) - - super_file_name = self.imported_mapping[super_class] # we need to get the parsed tree - model_name = re.search(r"models\.\w*?\.\w*?_(\S*)", super_file_name) - if model_name: - model_name = model_name.groups()[0] + relative_order = {} + idx = 0 + + original_dependencies = [] + other_files_dependencies = defaultdict(list) + for dep in tuple(missing_dependencies): + if dep in self.added_objects_file_mapping: + file = self.added_objects_file_mapping[dep] + other_files_dependencies[file].append(dep) else: - raise ValueError( - f"Tried parsing the name of the imported package from {super_file_name}, could not extract the model name" - ) - file_type = re.search(r"models?\.\w*?\.(\w*?)_", super_file_name).groups()[0] - visited_module = self.visited_module - if super_file_name not in visited_module: # only extract classes once - class_finder = find_classes_in_file( - self.transformers_imports[super_file_name], - model_name, - self.model_name, - self.given_old_name, - self.given_new_name, - ) - visited_module[super_file_name] = class_finder - list_dependencies = { - dep: class_finder.class_start_line.get(dep, 1000) - for dep in class_finder.class_dependency_mapping.get(class_name, []) - } - else: # we are re-using the previously parsed data - class_finder = visited_module[super_file_name] - - list_dependencies = { - dep: class_finder.class_start_line.get(dep, 1000) - for dep in class_finder.class_dependency_mapping.get(class_name, []) - } - if len(list_dependencies) == 0: - # so, maybe standard renaming did not work (the class name is different) - # we try with another renaming pattern - potential_given_name = get_new_part(class_name, super_class) - del visited_module[super_file_name] - class_finder = find_classes_in_file( - self.transformers_imports[super_file_name], - model_name, - potential_given_name, - self.model_name, - potential_given_name, - ) - list_dependencies = { - dep: class_finder.class_start_line.get(dep, 1000) - for dep in class_finder.class_dependency_mapping.get(class_name, []) - } - if len(list_dependencies) == 0: - # last recourse, if the suffix of the new class is different from the one of the super class - # e.g. 
MyNewClassForSegmentation extends MyOldClassForObjectDetection - # we try with another renaming pattern - class_finder = find_classes_in_file( - self.transformers_imports[super_file_name], - model_name, - self.model_name, - self.given_old_name, - self.given_new_name, - super_class, - class_name, - ) - visited_module[super_file_name] = class_finder - list_dependencies = { - dep: class_finder.class_start_line.get(dep, 1000) - for dep in class_finder.class_dependency_mapping.get(class_name, []) - } - if len(list_dependencies) == 0: - raise ValueError( - f"We were unable to find dependencies for {class_name} (based on inheriting from {super_class})" - f" Here are all the global dependencies that we found in you modular file: {list(class_finder.class_dependency_mapping.keys())}." - f" This usually means that the name of `{class_name}` does not match the pattern of `{super_class}`" - ) + original_dependencies.append(dep) + # Sort all lists according to the order in their respective file + all_dependencies = sorted(original_dependencies, key=lambda x: self.start_lines[x]) + for file, dependencies in other_files_dependencies.items(): + sorted_dependencies = sorted(dependencies, key=lambda x: self.start_lines_file_mapping[file][x]) + all_dependencies += sorted_dependencies + + # Add all original node first, then merged ones (one file at a time) + for dep in all_dependencies: + relative_order[dep] = idx + idx += 1 + + return relative_order + + +def check_dependencies_and_create_import_node( + file_type: str, new_dependencies: set[str], mapper: ModuleMapper, new_name: str +) -> tuple[set[str], dict[str, cst.CSTNode]]: + """Check that all class nodes in the `new_dependencies` belong to the correct `file_type`. If this is not the case, + we need to remove it from the dependencies, and create a new import to it instead. + This scenario may appear in the following case: + If a new class in the `modular_xxx.py` file does not belong to `type_xxx.py`, but is used somewhere in `other_type_xxx.py` + (e.g. as a type hint), but none of the visited files had a similar class, then it would be imported in `type_xxx.py` as + part of the standard dependency graph (because we never encountered an import towards this new class in any file). 
+ For example imagine the following `modular.py`: + ``` + from ..llama.modeling_llama import LlamaModel - list_dependencies = sorted(list_dependencies.items(), key=lambda x: x[1], reverse=True) - start_insert_idx = self.global_scope_index - file_to_update = self.files[file_type] - is_empty_node = self.python_module.code_for_node(original_node.body) == "pass\n" - for dependency, _ in list_dependencies: - # we can write to the correct body, using the source of the parent class - node = class_finder.global_nodes.get(dependency, None) - if node is not None: - if dependency not in file_to_update: - node = self.all_definitions.pop(dependency, node) - start_insert_idx -= 1 - file_to_update[dependency] = {"insert_idx": start_insert_idx, "node": node} - self.added_dependencies.add(dependency) - elif dependency not in self.inserted_deps: - # make sure the node is written after its dependencies - start_insert_idx = file_to_update[dependency]["insert_idx"] - 1 - if ( - dependency in file_to_update.keys() - and dependency in class_finder.first_lvl_dependency_mapping[class_name] - ): - # If dependency is defined, but not used, raise error - calls = m.findall(original_node, m.Call(func=m.Name(dependency))) - if not calls and not is_empty_node and dependency not in all_bases: - raise ValueError( - f"""You defined `{dependency}` in the modular_{self.model_name}.py, it should be used - when you define `{class_name}`, as it is one of it's direct dependencies. Make sure - you use it in the `__init__` function.""" - ) - self.inserted_deps.append(dependency) - - if len(list_dependencies) > 0: - updated_node = replace_call_to_super(class_finder, updated_node, class_name, all_bases) - - # Now, if a class was defined without parents, we look for the name - match_pattern = "|".join(TYPE_TO_FILE_TYPE.keys()) - match = re.search(rf"({match_pattern})$", class_name) - if match: - key = TYPE_TO_FILE_TYPE[match.group(1)] - self.class_to_file_type[class_name] = key - self.files[key][class_name] = {"insert_idx": self.global_scope_index, "node": updated_node} - else: - self.class_to_file_type[class_name] = "modeling" - self.files["modeling"][class_name] = {"insert_idx": self.global_scope_index, "node": updated_node} + class NewNameTextConfig(PretrainedConfig): + ... - self.current_class = None - return updated_node + class NewNameConfig(PretrainedConfig): + ... - def visit_FunctionDef(self, node): - parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node) - if m.matches(parent_node, m.Module()): - self.current_top_level_function = node.name.value + class NewNameModel(LlamaModel): + config = NewNameConfig() + text_config = NewNameTextConfig() + ... + ``` + then without the help of this function, `NewNameTextConfig` would be imported in the `modeling_newname.py` as well as + `configuration_newname.py`, because `modeling_llama.py` tells us to not import `NewNameConfig`, but has no + knowledge of `NewNameTextConfig`. 
+ """ + class_dependencies = {dep for dep in new_dependencies if m.matches(mapper.global_nodes[dep], m.ClassDef())} + corrected_dependencies = new_dependencies.copy() + new_imports = {} + for class_name in class_dependencies: + class_file_type = find_file_type(class_name) + # In this case, we need to remove it from the dependencies and create a new import instead + if class_file_type != file_type: + corrected_dependencies.remove(class_name) + import_statement = f"from .{class_file_type}_{new_name} import {class_name}" + new_imports[class_name] = cst.parse_statement(import_statement) + + return corrected_dependencies, new_imports + + +def get_class_node_and_dependencies( + modular_mapper: ModularFileMapper, class_name: str, node: cst.CSTNode, files: dict[str, dict] +) -> tuple[dict, str, dict]: + """Return a single class node (and all its dependency nodes), to be added to the `files`. It creates the new + class node based on the inherited classes if needed. Also returns any new imports of a new class defined in + the modular that we nay need. + """ + bases = [k.value.value for k in node.bases if k.value.value in modular_mapper.model_specific_imported_objects] + if len(bases) > 1: + raise ValueError( + f"{class_name} was defined with more than 1 model-specific super class. This is unsupported. We found {*bases,}." + ) - def leave_FunctionDef(self, original_node, node): - parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node) - if m.matches(parent_node, m.Module()): - self.all_definitions[node.name.value] = node - return node - - def visit_Assign(self, node: cst.Assign) -> None: - # Check if the assignment target is '__all__' - if isinstance(node.targets[0].target, cst.Name) and node.targets[0].target.value == "__all__": - if isinstance(node.value, cst.List): - # Extract the elements from the list - all_all_to_add = defaultdict(list) - for elt in node.value.elements: - if isinstance(elt.value, cst.SimpleString): - # Remove quotes and add the string to the elements list - class_name = elt.value.value - file = self.class_to_file_type[ - elt.value.evaluated_value - ] # evaluated value give the content of the string - all_all_to_add[file] += [class_name] - for f_type, new_alls in all_all_to_add.items(): - updated_node = node.with_changes( - value=cst.List(elements=[cst.Element(value=cst.SimpleString(value=k)) for k in new_alls]) - ) - self.files[f_type][class_name] = { - "insert_idx": self.global_scope_index + 100, - "node": updated_node, - } - - def leave_If(self, original_node, node): - parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node) - if m.matches(parent_node, m.Module()): - full_statement = self.python_module.code_for_node(original_node.test) - if re.search(r"[\s\S]*is_.*available", full_statement): - self.all_safe_imports.append(node) - elif full_statement not in self.all_imports: - logger.warning(f"one import is protected with `if`. Hard guess where it's used {full_statement}") - return node - - def visit_Call(self, node: cst.Call): - """This is used to create a mapping from functions to class calling them, and from top-level functions to functions called inside them. - Important note: we only rely on direct Call to the functions here, not indirect mentions (such as assigning a variable with the function, - add calling the variable later). 
This should be enough as the `modular_xxx` and `modeling_xxx` structures should be as simple as possible.""" - # Only map function calls if we're inside a class (i.e., current_class is set) - if self.current_class is not None: - # Simple function calls such as foo() - if isinstance(node.func, cst.Name): - self.function_call_class_mapping[node.func.value].add(self.current_class) - elif self.current_top_level_function is not None: - # Simple function calls such as foo() - if isinstance(node.func, cst.Name): - self.function_call_dependency_mapping[self.current_top_level_function].add(node.func.value) - - def _maybe_add_function_to_body( - self, - top_level_function: str, - body: dict, - function_node: cst.FunctionDef, - matching_callers: Optional[set] = None, - parent: Optional[str] = None, - ) -> bool: - """Check if the `top_level_function` should be added to the body (i.e. it is not already present, and `matching_callers` - is not empy, or `parent`is provided). If it should be added, do it (in the correct location, just before its caller) and return - `True`. Return `False` otherwise. - """ - if matching_callers is None and parent is None: - raise ValueError("Cannot add function if both the parent and the matching callers are None.") - if matching_callers is None: - matching_callers = {parent} - if len(matching_callers) > 0 and top_level_function not in body.keys(): - # Add the function just before the first class using it - new_idx = min([body[element]["insert_idx"] for element in matching_callers]) - # Reorder the elements - for element in body.keys(): - if body[element]["insert_idx"] >= new_idx: - body[element]["insert_idx"] += 1 - # Assign new element to body (after changing the count to avoid messing it) - body[top_level_function] = {"insert_idx": new_idx, "node": function_node} - return True - return False - - def _recursively_add_all_new_needed_functions_in_files(self): - """For all top-level functions which were newly defined in the `modular_xxx.py`, check if they are used in a class in - the different files, and add them to the file if it is the case (also recursively adding all other functions that - may be needed in that function body).""" - # At this point, `self.all_definitions` only contains newly defined top-level functions in the `modualr_xxx.py` - for top_level_function, function_node in self.all_definitions.items(): - calling_entities = self.function_call_class_mapping[top_level_function] - # The function may be needed in different files, we need to iterate on them - for file, body in self.files.items(): - file_elements = set(body.keys()) - # If the intersection is not null, top_level_func must be added to file - matching_callers = calling_entities & file_elements - added = self._maybe_add_function_to_body(top_level_function, body, function_node, matching_callers) - # If the function was added, we need to recursively add all its dependencies - if added: - for dependency, parent in find_all_dependencies( - top_level_function, self.function_call_dependency_mapping - ): - self._maybe_add_function_to_body( - dependency, body, self.all_definitions[dependency], parent=parent - ) + file_type = find_file_type(class_name) + file_to_update = files[file_type] + model_name = modular_mapper.model_name - def leave_Module(self, original_node: cst.Module, node): - imports = {self.python_module.code_for_node(k): k for k in self.all_imports} - dependency_imports = {file_type: imports.copy() for file_type in self.files} - for super_file_name, visiter in self.visited_module.items(): - 
file_type = re.search(r"models?\.\w*?\.(\w*?)_", super_file_name).groups()[0] - dependency_imports[file_type].update( - {self.python_module.code_for_node(k): k for k in visiter.imports.values()} - ) + # This is used to avoid adding objects to the dependencies graph if they will be imported already + imported_objects = modular_mapper.imported_objects_per_file[file_type] + + # We need to replace the class node with the transformers (modeling file) super class node + if len(bases) == 1: + super_class = bases[0] + super_file_name = modular_mapper.model_specific_imported_objects[super_class] + + # Get the mapper corresponding to the inherited class + mapper = modular_mapper.visited_modules[super_file_name] + # Rename the super class according to the exact same rule we used when renaming the whole module + renamer = modular_mapper.renamers[super_file_name] + renamed_super_class = preserve_case_replace(super_class, renamer.patterns, renamer.default_name) + renamed_super_class = convert_to_camelcase(renamed_super_class, renamer.old_name, renamer.default_old_name) + + # Create the new class node + updated_node = replace_class_node(mapper, node, renamed_super_class) + + # Grab all immediate dependencies of the new node + new_node_dependencies = augmented_dependencies_for_class_node(updated_node, mapper, imported_objects) + + # At this point, if any class dependency is found, but belongs to another file, it means that we need to remove + # it from the dependencies, and add a new import of it instead + new_node_dependencies, new_imports = check_dependencies_and_create_import_node( + file_type, new_node_dependencies, mapper, model_name + ) + + # The node was modified -> look for all recursive dependencies of the new node + all_dependencies_to_add = find_all_dependencies( + dependency_mapping=mapper.class_dependency_mapping, + initial_dependencies=new_node_dependencies, + initial_checked_dependencies=set(file_to_update.keys()), + ) + + relative_dependency_order = mapper.compute_relative_order(all_dependencies_to_add) + nodes_to_add = { + dep: (relative_dependency_order[dep], mapper.global_nodes[dep]) for dep in all_dependencies_to_add + } + + # No transformers (modeling file) super class, just check functions and assignments dependencies + else: + updated_node = node + # The node was NOT modified -> no need to look recursively for other class dependencies. 
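The super-class renaming step above goes through the converter's `preserve_case_replace` / `convert_to_camelcase` helpers, which are not shown in this diff. The rough idea, as a standalone sketch with hard-coded `llama` → `gemma` names, is case-preserving substitution:

```python
import re

def rename_preserving_case(text, old="llama", new="gemma"):
    # Replace lower-case, capitalized and upper-case variants of the old model name.
    replacements = {old: new, old.capitalize(): new.capitalize(), old.upper(): new.upper()}
    pattern = re.compile("|".join(re.escape(k) for k in replacements))
    return pattern.sub(lambda m: replacements[m.group(0)], text)

print(rename_preserving_case("LlamaModel"))               # GemmaModel
print(rename_preserving_case("LLAMA_ATTENTION_CLASSES"))  # GEMMA_ATTENTION_CLASSES
print(rename_preserving_case("llama_rotary_embedding"))   # gemma_rotary_embedding
```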
Indeed, even if they are not + # already defined (which would mean a weird order of the code in the modular...), they will be in the future + all_dependencies_to_add = augmented_dependencies_for_class_node(updated_node, modular_mapper, imported_objects) + + # At this point, if any class dependency is found, but belongs to another file, it means that we need to remove + # it from the dependencies, and add a new import of it instead + all_dependencies_to_add, new_imports = check_dependencies_and_create_import_node( + file_type, all_dependencies_to_add, modular_mapper, model_name + ) + + relative_dependency_order = modular_mapper.compute_relative_order(all_dependencies_to_add) + nodes_to_add = { + dep: (relative_dependency_order[dep], modular_mapper.global_nodes[dep]) + for dep in all_dependencies_to_add + if dep not in file_to_update.keys() + } + + # Add the class node itself to the nodes to add + class_idx = max(relative_dependency_order.values()) + 1 if len(relative_dependency_order) > 0 else 0 + nodes_to_add[class_name] = (class_idx, updated_node) + + return nodes_to_add, file_type, new_imports + + +def create_modules(modular_mapper: ModularFileMapper) -> dict[str, cst.Module]: + """Create all the new modules based on visiting the modular file. It replaces all classes as necesary.""" + files = defaultdict(dict) + current_file_indices = defaultdict(lambda: 0) + + # For each class defined in modular, potentially replace the node and add it with its dependencies + for class_name, node in modular_mapper.classes.items(): + nodes_to_add, file_type, new_imports = get_class_node_and_dependencies(modular_mapper, class_name, node, files) + + # Add the new potential new imports that we may need to the `modular_mapper` variable + modular_mapper.imported_objects_per_file[file_type].update(new_imports.keys()) + modular_mapper.imports.extend(list(new_imports.values())) + + # Sort the nodes according to their relative order + nodes_to_add = sorted(nodes_to_add.items(), key=lambda x: x[1][0]) + # Write all nodes to file + for dependency, (_, node) in nodes_to_add: + # This is used to keep certain variables at the beginning of the file + try: + # The -1000 is arbitrary -> just keep it bigger than the list + idx = -1000 + VARIABLES_AT_THE_BEGINNING.index(dependency) + except ValueError: + idx = current_file_indices[file_type] + current_file_indices[file_type] += 1 + files[file_type][dependency] = {"insert_idx": idx, "node": node} + + # Add the __all__ statement to files at the end + for file_type, node in modular_mapper.all_all_to_add.items(): + idx = current_file_indices[file_type] + files[file_type]["__all__"] = {"insert_idx": idx, "node": node} + + # Aggregate all the imports statements (we look for duplicates with the code_for_node, not the nodes themselves because + # they are wrapped in SimpleStatementLine or If which could have different newlines, blanks etc) + all_imports = modular_mapper.imports.copy() + all_imports_code = {modular_mapper.python_module.code_for_node(node).strip() for node in all_imports} + for file, mapper in modular_mapper.visited_modules.items(): + new_imports = [ + node for node in mapper.imports if mapper.python_module.code_for_node(node).strip() not in all_imports_code + ] + new_imports_code = {mapper.python_module.code_for_node(node).strip() for node in new_imports} + all_imports.extend(new_imports) + all_imports_code.update(new_imports_code) - # Check if any new top-level function from the `modular_xxx.py` should be added to the different files - # (if it is called in a 
class in the file, then it will be copy pasted from `modular.py` to that file). - self._recursively_add_all_new_needed_functions_in_files() + # Find the correct imports, and write the new modules + for file, body in files.items(): + new_body = [k[1]["node"] for k in sorted(body.items(), key=lambda x: x[1]["insert_idx"])] + needed_imports = get_needed_imports(body, all_imports) + full_module = needed_imports + new_body + new_module = cst.Module(body=full_module, header=modular_mapper.python_module.header) + files[file] = new_module - for file, body in self.files.items(): - new_body = [k[1]["node"] for k in sorted(body.items(), key=lambda x: x[1]["insert_idx"])] - if len(new_body) > 0: - if file in dependency_imports.keys(): - new_body = list(dependency_imports[file].values()) + new_body - new_module = cst.Module(body=[*new_body], header=node.header) - # Final cleanup - new_module = MetadataWrapper(new_module).visit(PostModularConverterCleaner(self.added_dependencies)) - self.files[file] = new_module - return node + return files def convert_modular_file(modular_file, old_model_name=None, new_model_name=None, cst_transformers=None): @@ -1137,10 +1442,10 @@ def convert_modular_file(modular_file, old_model_name=None, new_model_name=None, module = cst.parse_module(code) wrapper = MetadataWrapper(module) if cst_transformers is None: - cst_transformers = ModularConverterTransformer(module, model_name, old_model_name, new_model_name) + cst_transformers = ModularFileMapper(module, model_name, old_model_name, new_model_name) wrapper.visit(cst_transformers) - for file, node in cst_transformers.files.items(): - if node != {}: + for file, module in create_modules(cst_transformers).items(): + if module != {}: # Get relative path starting from src/transformers/ relative_path = re.search( r"(src/transformers/.*|examples/.*)", os.path.abspath(modular_file).replace("\\", "/") @@ -1149,7 +1454,7 @@ def convert_modular_file(modular_file, old_model_name=None, new_model_name=None, header = AUTO_GENERATED_MESSAGE.format( relative_path=relative_path, short_name=os.path.basename(relative_path) ) - ruffed_code = run_ruff(header + node.code, True) + ruffed_code = run_ruff(header + module.code, True) formatted_code = run_ruff(ruffed_code, False) output[file] = [formatted_code, ruffed_code] return output @@ -1180,7 +1485,7 @@ def save_modeling_file(modular_file, converted_file): parser = argparse.ArgumentParser() parser.add_argument( "--files_to_parse", - default=["src/transformers/models/roberta/modular_roberta.py"], + default=["src/transformers/models/gemma/modular_gemma.py"], nargs="+", help="A list of `modular_xxxx` files that should be converted to single model file", ) @@ -1197,6 +1502,7 @@ def save_modeling_file(modular_file, converted_file): args = parser.parse_args() if args.files_to_parse == ["all"]: args.files_to_parse = glob.glob("src/transformers/models/**/modular_*.py", recursive=True) + args.files_to_parse += glob.glob("examples/**/modular_*.py", recursive=True) for file_name in find_priority_list(args.files_to_parse): print(f"Converting {file_name} to a single model single file format")
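For reference, the per-file assembly performed in `create_modules` and reused by `convert_modular_file` amounts to sorting each body dict by `insert_idx` and wrapping it, together with its imports, into a fresh `cst.Module`. A minimal sketch with illustrative node names and contents:

```python
import libcst as cst

# Each output file is a dict of {name: {"insert_idx": int, "node": cst.CSTNode}}.
body = {
    "_CHECKPOINT_FOR_DOC": {"insert_idx": 0, "node": cst.parse_statement('_CHECKPOINT_FOR_DOC = "dummy/checkpoint"')},
    "rotate_half": {"insert_idx": 1, "node": cst.parse_statement("def rotate_half(x): return x")},
}
needed_imports = [cst.parse_statement("import torch")]

# Sort by insertion index and prepend the imports kept by get_needed_imports.
new_body = [entry["node"] for _, entry in sorted(body.items(), key=lambda x: x[1]["insert_idx"])]
new_module = cst.Module(body=needed_imports + new_body)
print(new_module.code)
```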