-
Notifications
You must be signed in to change notification settings - Fork 27.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fix cache handling and causality for cross attention
- Loading branch information
Showing
4 changed files
with
2,469 additions
and
16 deletions.
There are no files selected for viewing
236 changes: 236 additions & 0 deletions
236
src/transformers/models/moonshine/configuration_moonshine.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,236 @@ | ||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 | ||
# This file was automatically generated from src/transformers/models/moonshine/modular_moonshine.py. | ||
# Do NOT edit this file manually as any edits will be overwritten by the generation of | ||
# the file from the modular. If any change should be done, please apply the change to the | ||
# modular_moonshine.py file directly. One of our CI enforces this. | ||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 | ||
|
||
from ...configuration_utils import PretrainedConfig | ||
|
||
|
||
class MoonshineConfig(PretrainedConfig): | ||
r""" | ||
This is the configuration class to store the configuration of a [`MoonshineModel`]. It is used to instantiate a Moonshine | ||
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the | ||
defaults will yield a similar configuration to that of the Moonshine | ||
[UsefulSensors/moonshine](https://huggingface.co/UsefulSensors/moonshine). | ||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the | ||
documentation from [`PretrainedConfig`] for more information. | ||
Args: | ||
vocab_size (`int`, *optional*, defaults to 32768): | ||
Vocabulary size of the Moonshine model. Defines the number of different tokens that can be represented by the | ||
`inputs_ids` passed when calling [`MoonshineModel`]. | ||
hidden_size (`int`, *optional*, defaults to 288): | ||
Dimension of the hidden representations. | ||
intermediate_size (`int`, *optional*): | ||
Dimension of the MLP representations. | ||
num_hidden_layers (`int`, *optional*, defaults to 6): | ||
Number of hidden layers in the Transformer encoder and decoder. | ||
num_attention_heads (`int`, *optional*, defaults to 8): | ||
Number of attention heads for each attention layer in the Transformer encoder and decoder. | ||
num_key_value_heads (`int`, *optional*): | ||
This is the number of key_value heads that should be used to implement Grouped Query Attention. If | ||
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if | ||
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When | ||
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed | ||
by meanpooling all the original heads within that group. For more details checkout [this | ||
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to | ||
`num_attention_heads`. | ||
encoder_hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): | ||
The non-linear activation function (function or string) in the encoder. | ||
decoder_hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): | ||
The non-linear activation function (function or string) in the decoder. | ||
max_position_embeddings (`int`, *optional*, defaults to 2048): | ||
The maximum sequence length that this model might ever be used with. TODO: check this | ||
initializer_range (`float`, *optional*, defaults to 0.02): | ||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices. | ||
layer_norm_eps (`float`, *optional*, defaults to 1e-5): | ||
The epsilon used by the layer normalization layers. | ||
decoder_start_token_id (`int`, *optional*, defaults to 1): | ||
Corresponds to the "<|startoftranscript|>" token, which is automatically used when no `decoder_input_ids` | ||
are provided to the `generate` function. It is used to guide the model`s generation process depending on | ||
the task. | ||
use_cache (`bool`, *optional*, defaults to `True`): | ||
Whether or not the model should return the last key/values attentions (not used by all models). | ||
is_encoder_decoder (`bool`, *optional*, defaults to `True`): | ||
Whether the model is used as an encoder/decoder or not. | ||
rope_theta (`float`, *optional*, defaults to 10000.0): | ||
The base period of the RoPE embeddings. TODO: check this | ||
partial_rotary_factor (`float`, *optional*, defaults to 0.5): | ||
Percentage of the query and keys which will have rotary embedding. TODO: check this | ||
ff_mult (`int`, *optional*, defaults to 4): | ||
Factor by which to scale the intermediate size. | ||
attention_bias (`bool`, *optional*, defaults to `False`): | ||
Whether to use a bias in the query, key, value and output projection layers during self-attention. | ||
attention_dropout (`float`, *optional*, defaults to 0.0): | ||
The dropout ratio for the attention probabilities. | ||
qk_layernorm (`bool`, *optional*, defaults to `False`): | ||
Whether or not to normalize the Queries and Keys after projecting the hidden states. | ||
rope_scaling (`Dict`, *optional*): | ||
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type | ||
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value | ||
accordingly. | ||
Expected contents: | ||
`rope_type` (`str`): | ||
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', | ||
'llama3'], with 'default' being the original RoPE implementation. | ||
`factor` (`float`, *optional*): | ||
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In | ||
most scaling types, a `factor` of x will enable the model to handle sequences of length x * | ||
original maximum pre-trained length. | ||
`original_max_position_embeddings` (`int`, *optional*): | ||
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during | ||
pretraining. | ||
`attention_factor` (`float`, *optional*): | ||
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention | ||
computation. If unspecified, it defaults to value recommended by the implementation, using the | ||
`factor` field to infer the suggested value. | ||
`beta_fast` (`float`, *optional*): | ||
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear | ||
ramp function. If unspecified, it defaults to 32. | ||
`beta_slow` (`float`, *optional*): | ||
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear | ||
ramp function. If unspecified, it defaults to 1. | ||
`short_factor` (`List[float]`, *optional*): | ||
Only used with 'longrope'. The scaling factor to be applied to short contexts (< | ||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden | ||
size divided by the number of attention heads divided by 2 | ||
`long_factor` (`List[float]`, *optional*): | ||
Only used with 'longrope'. The scaling factor to be applied to long contexts (< | ||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden | ||
size divided by the number of attention heads divided by 2 | ||
`low_freq_factor` (`float`, *optional*): | ||
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE | ||
`high_freq_factor` (`float`, *optional*): | ||
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE | ||
bos_token_id (`int`, *optional*, defaults to 1): | ||
Denotes beginning of sequences token id. | ||
eos_token_id (`int`, *optional*, defaults to 2): | ||
Denotes end of sequences token id. | ||
apply_spec_augment (`bool`, *optional*, defaults to `False`): | ||
Whether to apply *SpecAugment* data augmentation to the outputs of the feature encoder. For reference see | ||
[SpecAugment: A Simple Data Augmentation Method for Automatic Speech | ||
Recognition](https://arxiv.org/abs/1904.08779). | ||
mask_time_prob (`float`, *optional*, defaults to 0.05): | ||
Percentage (between 0 and 1) of all feature vectors along the time axis which will be masked. The masking | ||
procecure generates `mask_time_prob*len(time_axis)/mask_time_length` independent masks over the axis. If | ||
reasoning from the propability of each feature vector to be chosen as the start of the vector span to be | ||
masked, *mask_time_prob* should be `prob_vector_start*mask_time_length`. Note that overlap may decrease the | ||
actual percentage of masked vectors. This is only relevant if `apply_spec_augment == True`. | ||
mask_time_length (`int`, *optional*, defaults to 10): | ||
Length of vector span along the time axis. | ||
mask_time_min_masks (`int`, *optional*, defaults to 2),: | ||
The minimum number of masks of length `mask_feature_length` generated along the time axis, each time step, | ||
irrespectively of `mask_feature_prob`. Only relevant if ''mask_time_prob*len(time_axis)/mask_time_length < | ||
mask_time_min_masks'' | ||
mask_feature_prob (`float`, *optional*, defaults to 0.0): | ||
Percentage (between 0 and 1) of all feature vectors along the feature axis which will be masked. The | ||
masking procecure generates `mask_feature_prob*len(feature_axis)/mask_time_length` independent masks over | ||
the axis. If reasoning from the propability of each feature vector to be chosen as the start of the vector | ||
span to be masked, *mask_feature_prob* should be `prob_vector_start*mask_feature_length`. Note that overlap | ||
may decrease the actual percentage of masked vectors. This is only relevant if `apply_spec_augment is | ||
True`. | ||
mask_feature_length (`int`, *optional*, defaults to 10): | ||
Length of vector span along the feature axis. | ||
mask_feature_min_masks (`int`, *optional*, defaults to 0),: | ||
The minimum number of masks of length `mask_feature_length` generated along the feature axis, each time | ||
step, irrespectively of `mask_feature_prob`. Only relevant if | ||
`mask_feature_prob*len(feature_axis)/mask_feature_length < mask_feature_min_masks`. | ||
Example: | ||
```python | ||
>>> from transformers import MoonshineModel, MoonshineConfig | ||
>>> # Initializing a Moonshine style configuration | ||
>>> configuration = MoonshineConfig().from_pretrained("UsefulSensors/moonshine") | ||
>>> # Initializing a model from the configuration | ||
>>> model = MoonshineModel(configuration) | ||
>>> # Accessing the model configuration | ||
>>> configuration = model.config | ||
```""" | ||
|
||
model_type = "moonshine" | ||
keys_to_ignore_at_inference = ["past_key_values"] | ||
|
||
def __init__( | ||
self, | ||
vocab_size=32768, | ||
hidden_size=288, | ||
intermediate_size=None, | ||
num_hidden_layers=6, | ||
num_attention_heads=8, | ||
num_key_value_heads=None, | ||
encoder_hidden_act="gelu", | ||
decoder_hidden_act="silu", | ||
max_position_embeddings=2048, | ||
initializer_range=0.02, | ||
layer_norm_eps=1e-5, | ||
decoder_start_token_id=1, | ||
use_cache=True, | ||
is_encoder_decoder=True, | ||
rope_theta=10000.0, | ||
partial_rotary_factor=0.5, | ||
attention_bias=False, | ||
attention_dropout=0.0, | ||
qk_layernorm=False, | ||
rope_scaling=None, | ||
ff_mult=4, | ||
bos_token_id=1, | ||
eos_token_id=2, | ||
apply_spec_augment=False, | ||
mask_time_prob=0.05, | ||
mask_time_length=10, | ||
mask_time_min_masks=2, | ||
mask_feature_prob=0.0, | ||
mask_feature_length=10, | ||
mask_feature_min_masks=0, | ||
**kwargs, | ||
): | ||
self.vocab_size = vocab_size | ||
self.hidden_size = hidden_size | ||
self.intermediate_size = intermediate_size | ||
self.num_hidden_layers = num_hidden_layers | ||
self.num_attention_heads = num_attention_heads | ||
|
||
if num_key_value_heads is None: | ||
num_key_value_heads = num_attention_heads | ||
|
||
self.num_key_value_heads = num_key_value_heads | ||
self.encoder_hidden_act = encoder_hidden_act | ||
self.decoder_hidden_act = decoder_hidden_act | ||
self.max_position_embeddings = max_position_embeddings | ||
self.initializer_range = initializer_range | ||
self.layer_norm_eps = layer_norm_eps | ||
self.decoder_start_token_id = decoder_start_token_id | ||
self.use_cache = use_cache | ||
self.is_encoder_decoder = is_encoder_decoder | ||
self.rope_theta = rope_theta | ||
self.partial_rotary_factor = partial_rotary_factor | ||
|
||
self.attention_bias = attention_bias | ||
self.attention_dropout = attention_dropout | ||
self.qk_layernorm = qk_layernorm | ||
self.rope_scaling = rope_scaling | ||
self.ff_mult = ff_mult | ||
|
||
# fine-tuning config parameters for SpecAugment: https://arxiv.org/abs/1904.08779 | ||
self.apply_spec_augment = apply_spec_augment | ||
self.mask_time_prob = mask_time_prob | ||
self.mask_time_length = mask_time_length | ||
self.mask_time_min_masks = mask_time_min_masks | ||
self.mask_feature_prob = mask_feature_prob | ||
self.mask_feature_length = mask_feature_length | ||
self.mask_feature_min_masks = mask_feature_min_masks | ||
|
||
super().__init__( | ||
bos_token_id=bos_token_id, | ||
eos_token_id=eos_token_id, | ||
is_encoder_decoder=is_encoder_decoder, | ||
decoder_start_token_id=decoder_start_token_id, | ||
**kwargs, | ||
) |
Oops, something went wrong.