From 9613933b022ddbf085e2c593ed4ceea4c734179a Mon Sep 17 00:00:00 2001 From: Yu Chin Fabian Lim Date: Thu, 19 Dec 2024 03:18:17 +0800 Subject: [PATCH] Add the Bamba Model (#34982) * initial commit for PR Co-authored-by: Gabe Goodhart * rename dynamic cache Signed-off-by: Yu Chin Fabian Lim * add more unit tests Signed-off-by: Yu Chin Fabian Lim * add integration test Signed-off-by: Yu Chin Fabian Lim * add integration test Signed-off-by: Yu Chin Fabian Lim * Add modular bamba file * Remove trainer changes from unrelated PR * Modify modular and cofig to get model running * Fix some CI errors and beam search * Fix a plethora of bugs from CI/docs/etc * Add bamba to models with special caches * Updat to newer mamba PR for mamba sublayer * fix test_left_padding_compatibility Signed-off-by: Yu Chin Fabian Lim * fix style Signed-off-by: Yu Chin Fabian Lim * fix remaining tests Signed-off-by: Yu Chin Fabian Lim * missed this test Signed-off-by: Yu Chin Fabian Lim * ran make style Signed-off-by: Yu Chin Fabian Lim * move slow tag to integration obj Signed-off-by: Yu Chin Fabian Lim * make style Signed-off-by: Yu Chin Fabian Lim * address comments Signed-off-by: Yu Chin Fabian Lim * fix modular Signed-off-by: Yu Chin Fabian Lim * left out one part of modular Signed-off-by: Yu Chin Fabian Lim * change model Signed-off-by: Yu Chin Fabian Lim * Make Rotary modular as well * Update bamba.md Added overview, update Model inference card and added config * Update bamba.md * Update bamba.md * Update bamba.md Minor fixes * Add docs for config and model back Signed-off-by: Antoni Viros i Martin * Add warning when using fast kernels * replaced generate example Signed-off-by: Yu Chin Fabian Lim * Address comments from PR Signed-off-by: Antoni Viros i Martin * Propagate attention fixes Signed-off-by: Antoni Viros i Martin * Fix attention interfaces to the new API Signed-off-by: Antoni Viros i Martin * Fix API for decoder layer Signed-off-by: Antoni Viros i Martin * Remove extra weights Signed-off-by: Antoni Viros i Martin --------- Signed-off-by: Yu Chin Fabian Lim Signed-off-by: Antoni Viros i Martin Co-authored-by: Gabe Goodhart Co-authored-by: Antoni Viros i Martin Co-authored-by: divya-kumari32 <72085811+divya-kumari32@users.noreply.github.com> Co-authored-by: Antoni Viros --- docs/source/en/_toctree.yml | 2 + docs/source/en/index.md | 1 + docs/source/en/model_doc/bamba.md | 64 + docs/source/en/perf_infer_gpu_one.md | 2 + src/transformers/__init__.py | 10 + src/transformers/generation/utils.py | 1 + src/transformers/models/__init__.py | 1 + .../models/auto/configuration_auto.py | 2 + src/transformers/models/auto/modeling_auto.py | 2 + src/transformers/models/bamba/__init__.py | 28 + .../models/bamba/configuration_bamba.py | 206 +++ .../bamba/convert_mamba_ssm_checkpoint.py | 273 +++ .../models/bamba/modeling_bamba.py | 1615 +++++++++++++++++ .../models/bamba/modular_bamba.py | 1303 +++++++++++++ src/transformers/utils/dummy_pt_objects.py | 21 + tests/generation/test_utils.py | 1 + tests/models/bamba/__init__.py | 0 tests/models/bamba/test_modeling_bamba.py | 603 ++++++ utils/check_config_attributes.py | 3 + 19 files changed, 4138 insertions(+) create mode 100644 docs/source/en/model_doc/bamba.md create mode 100644 src/transformers/models/bamba/__init__.py create mode 100644 src/transformers/models/bamba/configuration_bamba.py create mode 100644 src/transformers/models/bamba/convert_mamba_ssm_checkpoint.py create mode 100644 src/transformers/models/bamba/modeling_bamba.py create mode 100644 src/transformers/models/bamba/modular_bamba.py create mode 100644 tests/models/bamba/__init__.py create mode 100644 tests/models/bamba/test_modeling_bamba.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 435b482df599cf..c30dfd3fbabc97 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -322,6 +322,8 @@ sections: - local: model_doc/albert title: ALBERT + - local: model_doc/bamba + title: Bamba - local: model_doc/bart title: BART - local: model_doc/barthez diff --git a/docs/source/en/index.md b/docs/source/en/index.md index 3bd1c286d43240..0bd81e9d61be29 100644 --- a/docs/source/en/index.md +++ b/docs/source/en/index.md @@ -66,6 +66,7 @@ Flax), PyTorch, and/or TensorFlow. | [AriaText](model_doc/aria_text) | ✅ | ❌ | ❌ | | [Audio Spectrogram Transformer](model_doc/audio-spectrogram-transformer) | ✅ | ❌ | ❌ | | [Autoformer](model_doc/autoformer) | ✅ | ❌ | ❌ | +| [Bamba](model_doc/bamba) | ✅ | ❌ | ❌ | | [Bark](model_doc/bark) | ✅ | ❌ | ❌ | | [BART](model_doc/bart) | ✅ | ✅ | ✅ | | [BARThez](model_doc/barthez) | ✅ | ✅ | ✅ | diff --git a/docs/source/en/model_doc/bamba.md b/docs/source/en/model_doc/bamba.md new file mode 100644 index 00000000000000..4ea8475edb885a --- /dev/null +++ b/docs/source/en/model_doc/bamba.md @@ -0,0 +1,64 @@ + + +# Bamba + + +## Overview + +Bamba-9B is a decoder-only language model based on the [Mamba-2](https://github.com/state-spaces/mamba) architecture and is designed to handle a wide range of text generation tasks. It is trained from scratch using a two-stage training approach. In the first stage, the model is trained on 2 trillion tokens from the Dolma v1.7 dataset. In the second stage, it undergoes additional training on 200 billion tokens, leveraging a carefully curated blend of high-quality data to further refine its performance and enhance output quality. + +Checkout all Bamba-9B model checkpoints [here](https://github.com/foundation-model-stack/bamba). + +## BambaConfig + +| Model | Params | # Layers | Hidden Dim. | Attention Heads | GQA | KV Heads | Context Length | Tied Embeddings | +|-------------------|--------------|----------|-------------|-----------------|-----|----------|----------------|------------------| +| Bamba | 9B (9.78B) | 32 | 4096 | 32 | Yes | 8 | 4096 | True | + +[[autodoc]] BambaConfig + + + +## BambaForCausalLM + +```python +from transformers import AutoModelForCausalLM, AutoTokenizer + +model = AutoModelForCausalLM.from_pretrained("ibm-fms/Bamba-9B") +tokenizer = AutoTokenizer.from_pretrained("ibm-fms/Bamba-9B") + +message = ["Mamba is a snake with following properties "] +inputs = tokenizer(message, return_tensors='pt', return_token_type_ids=False) +response = model.generate(**inputs, max_new_tokens=64) +print(tokenizer.batch_decode(response, skip_special_tokens=True)[0]) +``` + +[[autodoc]] BambaForCausalLM + - forward + +This HF implementation is contributed by [ani300](https://github.com/ani300) and [fabianlim](https://github.com/fabianlim). diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md index cbb498070d69e5..4f9cace5b8d30d 100644 --- a/docs/source/en/perf_infer_gpu_one.md +++ b/docs/source/en/perf_infer_gpu_one.md @@ -39,6 +39,7 @@ FlashAttention-2 is experimental and may change considerably in future versions. FlashAttention-2 is currently supported for the following architectures: * [Aria](https://huggingface.co/docs/transformers/model_doc/aria#transformers.AriaForConditionalGeneration) * [Bark](https://huggingface.co/docs/transformers/model_doc/bark#transformers.BarkModel) +* [Bamba](https://huggingface.co/docs/transformers/model_doc/bamba#transformers.BambaModel) * [Bart](https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartModel) * [Chameleon](https://huggingface.co/docs/transformers/model_doc/chameleon#transformers.Chameleon) * [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPModel) @@ -220,6 +221,7 @@ For now, Transformers supports SDPA inference and training for the following arc * [Albert](https://huggingface.co/docs/transformers/model_doc/albert#transformers.AlbertModel) * [Aria](https://huggingface.co/docs/transformers/model_doc/aria#transformers.AriaForConditionalGeneration) * [Audio Spectrogram Transformer](https://huggingface.co/docs/transformers/model_doc/audio-spectrogram-transformer#transformers.ASTModel) +* [Bamba](https://huggingface.co/docs/transformers/model_doc/bamba#transformers.BambaModel) * [Bart](https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartModel) * [Beit](https://huggingface.co/docs/transformers/model_doc/beit#transformers.BeitModel) * [Bert](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertModel) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 920dc334dbb2a4..6a180a90bbbaa2 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -193,6 +193,7 @@ "AutoTokenizer", ], "models.autoformer": ["AutoformerConfig"], + "models.bamba": ["BambaConfig"], "models.bark": [ "BarkCoarseConfig", "BarkConfig", @@ -1540,6 +1541,13 @@ "AutoformerPreTrainedModel", ] ) + _import_structure["models.bamba"].extend( + [ + "BambaForCausalLM", + "BambaModel", + "BambaPreTrainedModel", + ] + ) _import_structure["models.bark"].extend( [ "BarkCausalModel", @@ -5104,6 +5112,7 @@ from .models.autoformer import ( AutoformerConfig, ) + from .models.bamba import BambaConfig from .models.bark import ( BarkCoarseConfig, BarkConfig, @@ -6493,6 +6502,7 @@ AutoformerModel, AutoformerPreTrainedModel, ) + from .models.bamba import BambaForCausalLM, BambaModel, BambaPreTrainedModel from .models.bark import ( BarkCausalModel, BarkCoarseModel, diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index fe634141eca09b..05627e23de11ff 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1693,6 +1693,7 @@ def _supports_default_dynamic_cache(self) -> bool: self._supports_cache_class and "jamba" not in self.__class__.__name__.lower() and "zamba" not in self.__class__.__name__.lower() + and "bamba" not in self.__class__.__name__.lower() ) def _prepare_cache_for_generation( diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 5eb74fab5abe71..5b3c648428359d 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -20,6 +20,7 @@ audio_spectrogram_transformer, auto, autoformer, + bamba, bark, bart, barthez, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index d7d8281c2e3f03..8aba0e75b2690b 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -39,6 +39,7 @@ ("aria_text", "AriaTextConfig"), ("audio-spectrogram-transformer", "ASTConfig"), ("autoformer", "AutoformerConfig"), + ("bamba", "BambaConfig"), ("bark", "BarkConfig"), ("bart", "BartConfig"), ("beit", "BeitConfig"), @@ -337,6 +338,7 @@ ("aria_text", "AriaText"), ("audio-spectrogram-transformer", "Audio Spectrogram Transformer"), ("autoformer", "Autoformer"), + ("bamba", "Bamba"), ("bark", "Bark"), ("bart", "BART"), ("barthez", "BARThez"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 5d41ad42beea7e..770e4ea0775f76 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -39,6 +39,7 @@ ("aria_text", "AriaTextModel"), ("audio-spectrogram-transformer", "ASTModel"), ("autoformer", "AutoformerModel"), + ("bamba", "BambaModel"), ("bark", "BarkModel"), ("bart", "BartModel"), ("beit", "BeitModel"), @@ -471,6 +472,7 @@ [ # Model for Causal LM mapping ("aria_text", "AriaTextForCausalLM"), + ("bamba", "BambaForCausalLM"), ("bart", "BartForCausalLM"), ("bert", "BertLMHeadModel"), ("bert-generation", "BertGenerationDecoder"), diff --git a/src/transformers/models/bamba/__init__.py b/src/transformers/models/bamba/__init__.py new file mode 100644 index 00000000000000..c3920da849a333 --- /dev/null +++ b/src/transformers/models/bamba/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2024 IBM and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_bamba import * + from .modeling_bamba import * + from .processing_bamba import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/bamba/configuration_bamba.py b/src/transformers/models/bamba/configuration_bamba.py new file mode 100644 index 00000000000000..f84d63ec04a9c7 --- /dev/null +++ b/src/transformers/models/bamba/configuration_bamba.py @@ -0,0 +1,206 @@ +# coding=utf-8 +# Copyright 2024 IBM and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Bamba model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class BambaConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`BambaModel`]. It is used to instantiate a + BambaModel model according to the specified arguments, defining the model architecture. Instantiating a configuration + with defaults taken from [ibm-fms/Bamba-9.8b-2.2T-hf](https://huggingface.co/ibm-fms/Bamba-9.8b-2.2T-hf). + + The BambaModel is a hybrid [mamba2](https://github.com/state-spaces/mamba) architecture with SwiGLU. + The checkpoints are jointly trained by IBM, Princeton, and UIUC. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 128000): + Vocabulary size of the Bamba model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`BambaModel`] + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. Note that this is only relevant if the + model has a output word embedding layer. + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 14336): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 8): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + num_logits_to_keep (`int` or `None`, *optional*, defaults to 1): + Number of prompt logits to calculate during generation. If `None`, all logits will be calculated. If an + integer value, only last `num_logits_to_keep` logits will be calculated. Default is 1 because only the + logits of the last prompt token are needed for generation. For long sequences, the logits for the entire + sequence may use a lot of memory so, setting `num_logits_to_keep=1` will reduce memory footprint + significantly. + pad_token_id (`int`, *optional*, defaults to 0): + The id of the padding token. + bos_token_id (`int`, *optional*, defaults to 1): + The id of the "beginning-of-sequence" token. + eos_token_id (`int`, *optional*, defaults to 2): + The id of the "end-of-sequence" token. + max_position_embeddings (`int`, *optional*, defaults to 262144): + Max cached sequence length for the model + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + attn_layer_indices (`list`, *optional*): + Specifies the layer indices that will have full attention. Must contain values at most num_hidden_layers. + mamba_n_heads (`int`, *optional*, defaults to 128): + The number of mamba heads used in the v2 implementation. + mamba_d_head (`int`, *optional*, defaults to `"auto"`): + Head embeddding dimension size + mamba_n_groups (`int`, *optional*, defaults to 1): + The number of the mamba groups used in the v2 implementation. + mamba_d_state (`int`, *optional*, defaults to 256): + The dimension the mamba state space latents + mamba_d_conv (`int`, *optional*, defaults to 4): + The size of the mamba convolution kernel + mamba_expand (`int`, *optional*, defaults to 2): + Expanding factor (relative to hidden_size) used to determine the mamba intermediate size + mamba_chunk_size (`int`, *optional*, defaults to 256): + The chunks in which to break the sequence when doing prefill/training + mamba_conv_bias (`bool`, *optional*, defaults to `True`): + Flag indicating whether or not to use bias in the convolution layer of the mamba mixer block. + mamba_proj_bias (`bool`, *optional*, defaults to `False`): + Flag indicating whether or not to use bias in the input and output projections (["in_proj", "out_proj"]) of the mamba mixer block + + """ + + model_type = "bamba" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=128000, + tie_word_embeddings=False, + hidden_size=4096, + intermediate_size=14336, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=8, + hidden_act="silu", + initializer_range=0.02, + rms_norm_eps=1e-5, + use_cache=True, + num_logits_to_keep=1, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + max_position_embeddings=262144, + attention_dropout=0.0, + attn_layer_indices=None, + mamba_n_heads=128, + mamba_d_head="auto", + mamba_n_groups=1, + mamba_d_state=256, + mamba_d_conv=4, + mamba_expand=2, + mamba_chunk_size=256, + mamba_conv_bias=True, + mamba_proj_bias=False, + **kwargs, + ): + self.vocab_size = vocab_size + self.tie_word_embeddings = tie_word_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.max_position_embeddings = max_position_embeddings + self.attention_dropout = attention_dropout + self.attention_bias = False + self.mlp_bias = False + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + + self.use_cache = use_cache + self.num_logits_to_keep = num_logits_to_keep + + self.attn_layer_indices = attn_layer_indices + self.rope_theta = 10000.0 + self.rope_scaling = None + self.partial_rotary_factor = 0.5 + + mamba_intermediate = mamba_expand * hidden_size + + if mamba_intermediate % mamba_n_heads != 0: + raise ValueError("mamba_n_heads must divide mamba_expand * hidden_size") + + # for the mamba_v2, must satisfy the following + if mamba_d_head == "auto": + mamba_d_head = mamba_intermediate // mamba_n_heads + + if mamba_d_head * mamba_n_heads != mamba_intermediate: + raise ValueError("The dimensions for the Mamba head state do not match the model intermediate_size") + + self.mamba_n_heads = mamba_n_heads + self.mamba_d_head = mamba_d_head + self.mamba_n_groups = mamba_n_groups + self.mamba_d_state = mamba_d_state + self.mamba_d_conv = mamba_d_conv + self.mamba_expand = mamba_expand + self.mamba_chunk_size = mamba_chunk_size + self.mamba_conv_bias = mamba_conv_bias + self.mamba_proj_bias = mamba_proj_bias + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + @property + def layers_block_type(self): + return [ + "attention" if (self.attn_layer_indices and i in self.attn_layer_indices) else "mamba" + for i in range(self.num_hidden_layers) + ] + + +__all__ = ["BambaConfig"] diff --git a/src/transformers/models/bamba/convert_mamba_ssm_checkpoint.py b/src/transformers/models/bamba/convert_mamba_ssm_checkpoint.py new file mode 100644 index 00000000000000..a7b8cfc782907b --- /dev/null +++ b/src/transformers/models/bamba/convert_mamba_ssm_checkpoint.py @@ -0,0 +1,273 @@ +# coding=utf-8 +# Copyright 2024 IBM and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""This script can be used to convert checkpoints provided in the `mamba_ssm` library into the format provided in HuggingFace `transformers`. It depends on the `mamba2_ssm` package to be installed.""" + +import argparse +import json +import os +import re +from os import path +from typing import Dict, Union + +import torch +from huggingface_hub import split_torch_state_dict_into_shards +from safetensors.torch import save_file + +from transformers import AutoTokenizer +from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME + +from .configuration_bamba import BambaConfig + + +def convert_state_dict_from_mamba_ssm(original_sd: Dict) -> Dict[str, torch.Tensor]: + state_dict = {} + + for orig_k, param in original_sd.items(): + k = orig_k.replace("backbone", "model") + + # for embeddings + k = k.replace("embedding", "embed_tokens") + + # for mixer + k = k.replace("mixer", "mamba") + + # for final layernorm + k = k.replace("norm_f", "final_layernorm") + + # for block layernorm + k = re.sub(r"(\d+)\.norm\.", r"\1.input_layernorm.", k) + k = re.sub(r"(\d+)\.norm2\.", r"\1.pre_ff_layernorm.", k) + + # for mlp + k = k.replace("mlp.fc2", "feed_forward.down_proj") + + if "mlp.fc1" in k: + param, param2 = torch.chunk(param, 2, dim=0) + k2 = k.replace("mlp.fc1", "feed_forward.gate_proj") + state_dict[k2] = param2 + k = k.replace("mlp.fc1", "feed_forward.up_proj") + + if ("in_proj" in k and orig_k.replace("in_proj", "conv1d") in original_sd) or ( + "out_proj" in k and orig_k.replace("out_proj", "conv1d") in original_sd + ): + # then this must be a mamba + pass + else: + # for attn + # - because mixer was replaced to mamba above + k = k.replace("mamba.out_proj", "self_attn.o_proj") + if "mamba.in_proj" in k: + m, n = param.shape + d = (m - n) // 2 + param, param2, param3 = torch.split(param, [n, d, d], dim=0) + k2 = k.replace("mamba.in_proj", "self_attn.k_proj") + state_dict[k2] = param2 + k2 = k.replace("mamba.in_proj", "self_attn.v_proj") + state_dict[k2] = param3 + k = k.replace("mamba.in_proj", "self_attn.q_proj") + + state_dict[k] = param + + return state_dict + + +# Adapted from transformers.models.mamba.convert_mamba_ssm_checkpoint_to_pytorch.py +def convert_ssm_config_to_hf_config( + config_ssm: Dict, + **kwargs, +) -> BambaConfig: + """Convert a config from mamba_ssm to a BambaConfig from here.""" + hf_config: BambaConfig = BambaConfig(**kwargs) + + hf_config.architectures = ["BambaForCausalLM"] + + # Set important values from config and recalculate other resulting entries + hf_config.hidden_size = config_ssm["d_model"] + hf_config.intermediate_size = config_ssm["d_intermediate"] + hf_config.mamba_n_heads = (hf_config.hidden_size * hf_config.mamba_expand) // hf_config.mamba_d_head + hf_config.num_hidden_layers = config_ssm["n_layer"] + hf_config.tie_word_embeddings = config_ssm["tie_embeddings"] + + # currently this script assumes config_ssm belongs to v2 + if config_ssm["ssm_cfg"].get("layer") != "Mamba2": + raise ValueError("Conversion script only supports Mamba2") + + # Set attention values + attn_cfg = config_ssm.get("attn_cfg") + if attn_cfg: + assert attn_cfg["causal"], "Only support non-causal attention." + assert not attn_cfg["qkv_proj_bias"], "Only support no qkv bias." + assert not attn_cfg["out_proj_bias"], "Only support no out bias." + hf_config.attn_rotary_emb = attn_cfg["rotary_emb_dim"] + hf_config.num_attention_heads = attn_cfg["num_heads"] + hf_config.num_key_value_heads = attn_cfg["num_heads_kv"] + + attention_layer_indices = config_ssm.get("attn_layer_idx") + if attention_layer_indices: + hf_config.attn_layer_indices = attention_layer_indices + + # Padded vocab size, mostly of 16 but 32 is also very common in different models + vocab_size = config_ssm["vocab_size"] + pad_vocab_size_multiple = config_ssm["pad_vocab_size_multiple"] + if (vocab_size % pad_vocab_size_multiple) != 0: + vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple) + hf_config.vocab_size = vocab_size + + return hf_config + + +def save_single_safetensor( + state_dict: Dict, + save_directory: str, + metadata: Dict, +): + save_file( + state_dict, + os.path.join(save_directory, SAFE_WEIGHTS_NAME), + metadata, + ) + + +def save_sharded_safetensors( + state_dict: Dict, + save_directory: str, + metadata: Dict, + max_shard_size: Union[int, str] = "5GB", +): + filename_pattern = SAFE_WEIGHTS_NAME.replace(".bin", "{suffix}.bin").replace( + ".safetensors", "{suffix}.safetensors" + ) + state_dict_split = split_torch_state_dict_into_shards( + state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size + ) + index = { + "metadata": state_dict_split.metadata, + "weight_map": state_dict_split.tensor_to_filename, + } + # Save the index + with open(os.path.join(save_directory, SAFE_WEIGHTS_INDEX_NAME), "w", encoding="utf-8") as f: + content = json.dumps(index, indent=2, sort_keys=True) + "\n" + f.write(content) + + filename_to_tensors = state_dict_split.filename_to_tensors.items() + for shard_file, tensors in filename_to_tensors: + shard = {tensor: state_dict[tensor].contiguous() for tensor in tensors} + save_file(shard, os.path.join(save_directory, shard_file), metadata=metadata) + + +# Adapted from transformers.models.mamba.convert_mamba_ssm_checkpoint_to_pytorch.py +def convert_mamba_ssm_checkpoint_file_to_huggingface_model_file( + mamba_ssm_checkpoint_path: str, + precision: str, + output_dir: str, + tokenizer_path: str = None, + save_model: Union[bool, str] = True, +) -> None: + # load tokenizer if provided, this will be used to set the + # token_ids in the config file + token_ids = {} + if tokenizer_path: + tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) + for key in [ + "bos_token_id", + "eos_token_id", + "pad_token_id", + ]: + id = getattr(tokenizer, key, None) + if id: + token_ids[key] = id + + # there are some configs unsettable by mamba_ssn config, so + # if there are changes from the defaults, have to pass them into + # the function + unsettables = { + "mamba_d_head": 64, + "mamba_d_state": 128, + "mamba_n_groups": 1, + "rms_norm_eps": 1e-5, + } + + # Load and save config based on name + config_path = path.join(mamba_ssm_checkpoint_path, "config.json") + with open(config_path, "r", encoding="utf-8") as json_file: + config = json.load(json_file) + + # convert the config + hf_config = convert_ssm_config_to_hf_config( + config_ssm=config, + **token_ids, + **unsettables, + ) + hf_config.save_pretrained(output_dir) + + # Load state dict of the original model and transfer to hf model + state_dict = torch.load( + path.join(mamba_ssm_checkpoint_path, "pytorch_model.bin"), + map_location="cpu", + weights_only=True, + ) + # FIXME: allow other parameters to pass in + state_dict = convert_state_dict_from_mamba_ssm(state_dict) + + # Save new model to pytorch_dump_path + dtype = torch.float32 if precision == "fp32" else (torch.bfloat16 if precision == "bf16" else torch.float16) + + save_file_fn = None + if isinstance(save_model, bool) and save_model: + save_file_fn = save_single_safetensor + elif isinstance(save_model, str) and save_model == "sharded": + save_file_fn = save_sharded_safetensors + + if save_file_fn: + save_file_fn({k: v.to(dtype) for k, v in state_dict.items()}, output_dir, metadata={"format": "pt"}) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "-i", + "--mamba_ssm_checkpoint_directory", + type=str, + required=True, + help="Path to a directory containing the `pytorch_model.bin` mamba_ssm checkpoint file to be converted.", + ) + parser.add_argument( + "-p", + "--precision", + type=str, + default="fp16", + const="fp16", + required=True, + choices=("fp32", "fp16", "bf16"), + help="The precision the model will be saved in. Select from fp32, fp16 or bf16.", + ) + parser.add_argument( + "-o", "--output_dir", type=str, required=True, help="Path to directory to save the converted output model to." + ) + parser.add_argument( + "-t", + "--tokenizer_model_path", + type=str, + default=None, + required=False, + help="Path to a the tokenizer file.", + ) + args = parser.parse_args() + + convert_mamba_ssm_checkpoint_file_to_huggingface_model_file( + args.mamba2_checkpoint_directory, + args.precision, + args.output_dir, + ) diff --git a/src/transformers/models/bamba/modeling_bamba.py b/src/transformers/models/bamba/modeling_bamba.py new file mode 100644 index 00000000000000..c89d8d7853008d --- /dev/null +++ b/src/transformers/models/bamba/modeling_bamba.py @@ -0,0 +1,1615 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/bamba/modular_bamba.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_bamba.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2024 IBM and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Optional, Tuple, Union + +import torch +from torch import nn + +import transformers.models.jamba.modeling_jamba as modeling_jamba +from transformers.activations import ACT2FN + +from ...cache_utils import Cache # we need __iter__ and __len__ of pkv +from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings +from ...utils.import_utils import ( + is_causal_conv1d_available, + is_mamba_2_ssm_available, +) +from .configuration_bamba import BambaConfig + + +if is_mamba_2_ssm_available(): + from mamba_ssm.ops.triton.selective_state_update import selective_state_update + from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined +else: + selective_state_update = None + +if is_causal_conv1d_available(): + from causal_conv1d import causal_conv1d_fn, causal_conv1d_update +else: + causal_conv1d_update, causal_conv1d_fn = None, None + + +logger = logging.get_logger(__name__) +_CONFIG_FOR_DOC = "BambaConfig" + + +# Adapted from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache for the v2 mixer +class HybridMambaAttentionDynamicCache(modeling_jamba.HybridMambaAttentionDynamicCache): + """ + A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache + (which has a constant shape regardless of seq_len). + + This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states` + and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor + For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`, + while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors). + For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors), + while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`, + and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`. + """ + + def __init__(self, config: BambaConfig, batch_size, dtype=torch.float16, device=None): + super().__init__(config, batch_size, dtype, device) + self.layers_block_type = config.layers_block_type + self.has_previous_state = False # only used by mamba + conv_kernel_size = config.mamba_d_conv + ssm_state_size = config.mamba_d_state + + self.conv_states = [] + self.ssm_states = [] + self.transformer_layers = [] + for i in range(config.num_hidden_layers): + if self.layers_block_type[i] == "mamba": + self.conv_states += [ + torch.zeros( + batch_size, + (config.mamba_expand * config.hidden_size + 2 * config.mamba_n_groups * ssm_state_size), + conv_kernel_size, + device=device, + dtype=dtype, + ) + ] + self.ssm_states += [ + torch.zeros( + batch_size, + config.mamba_n_heads, + config.mamba_d_head, + ssm_state_size, + device=device, + dtype=dtype, + ) + ] + else: + self.conv_states += [torch.tensor([[]] * batch_size, device=device)] + self.ssm_states += [torch.tensor([[]] * batch_size, device=device)] + self.transformer_layers.append(i) + + self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] + self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] + + +class BambaRotaryEmbedding(nn.Module): + def __init__( + self, + config: BambaConfig, + device=None, + ): + super().__init__() + self.rope_kwargs = {} + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +# Adapted from transformers.models.glm.modular_glm.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Removes the interleaving of cos and sin from GLM + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + + # Keep half or full tensor for later concatenation + rotary_dim = cos.shape[-1] + q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] + k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] + + # Apply rotary embeddings on the first half or full tensor + q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin) + k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin) + + # Concatenate back to full shape + q_embed = torch.cat([q_embed, q_pass], dim=-1) + k_embed = torch.cat([k_embed, k_pass], dim=-1) + return q_embed, k_embed + + +class BambaAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: BambaConfig, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class BambaRMSNormGated(torch.nn.Module): + def __init__(self, hidden_size, eps=1e-6): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states, gate=None): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + + if gate is not None: + hidden_states = hidden_states * nn.functional.silu(gate.to(torch.float32)) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + + return self.weight * hidden_states.to(input_dtype) + + +# Helper methods for segment sum computation + + +def pad_tensor_by_size(input_tensor: torch.Tensor, pad_size: int): + """ + Padding x tensor with `pad_size` on the seq_len dim (dim=1) + + Assumes that we only have tensors of either size 4 or 3 + """ + pad_shape = (0, 0, 0, 0, 0, pad_size, 0, 0) if len(input_tensor.shape) == 4 else (0, 0, 0, pad_size, 0, 0) + + return torch.nn.functional.pad(input_tensor, pad_shape, mode="constant", value=0) + + +def reshape_into_chunks(input_tensor, pad_size, chunk_size): + """ + Padding input_tensor with `pad_size` on the seq_len dim (dim=1) and + simultaneously splitting it into chunk sequences. + + Assumes that we only have tensors of either size 4 or 3 + """ + # [bsz, seq_len, ...] -> [bsz, seq_len multiple of chunk_size, ...] + input_tensor = pad_tensor_by_size(input_tensor, pad_size) + + if len(input_tensor.shape) == 3: + # [bsz, seq_len multiple of chunk_size, num_heads] -> [bsz, -1, chunk_size, num_heads] + return input_tensor.reshape(input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2]) + else: + # [bsz, seq_len multiple of chunk_size, num_heads, head_dim or state_size] -> [bsz, -1, chunk_size, num_heads, head_dim or state_size] + return input_tensor.reshape( + input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2], input_tensor.shape[3] + ) + + +def segment_sum(input_tensor): + """ + More stable segment sum calculation. Uses cumulative sums and masking instead of direct subtractions. + """ + chunk_size = input_tensor.size(-1) + # 1. expand input tensor to have an additional dimension and repeat along that dimension + # [..., chunk_size] -> [..., chunk_size, chunk_size] + input_tensor = input_tensor[..., None].expand(*input_tensor.size(), chunk_size) + # 2. create a lower triangular mask with the diagonal set to 0 to 0 out elements above diag + mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=-1) + input_tensor = input_tensor.masked_fill(~mask, 0) + # 3. compute actual cumsum + tensor_segsum = torch.cumsum(input_tensor, dim=-2) + + # 4. apply mask to keep only the lower triangular part of the cumulative sum result (incl diagonal this time) + mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=0) + tensor_segsum = tensor_segsum.masked_fill(~mask, -torch.inf) + return tensor_segsum + + +is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update)) + + +def apply_mask_to_padding_states(hidden_states, attention_mask): + """ + Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66 + """ + if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: + dtype = hidden_states.dtype + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) + + return hidden_states + + +# Adapted from transformers.models.mamba2.modeling_mamba2.Mamba2Mixer +class BambaMixer(nn.Module): + """ + Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`. + A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective) + ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4, + and is why Mamba is called **selective** state spaces) + + The are a few differences between this and Mamba2Mixer: + - The variable use_precomputed_states is slightly different due to the HybridCache structure + - There's a few non-obvious bugs fixed with batching in the slow path that exist in main + - Some extra variables that our layer doesn't need have been removed + - We ported most of the refactors in https://github.com/huggingface/transformers/pull/35154, which is (as of Dec 18, 2024) unmerged + """ + + def __init__(self, config: BambaConfig, layer_idx: int): + super().__init__() + self.num_heads = config.mamba_n_heads + self.hidden_size = config.hidden_size + self.ssm_state_size = config.mamba_d_state + self.conv_kernel_size = config.mamba_d_conv + self.intermediate_size = int(config.mamba_expand * self.hidden_size) + self.layer_idx = layer_idx + self.use_conv_bias = config.mamba_conv_bias + self.activation = config.hidden_act + self.act = ACT2FN[config.hidden_act] + self.use_bias = config.mamba_proj_bias + + self.layer_norm_epsilon = config.rms_norm_eps + + self.n_groups = config.mamba_n_groups + self.head_dim = config.mamba_d_head + self.chunk_size = config.mamba_chunk_size + + # FIXME: + self.time_step_limit = (0.0, float("inf")) + self.time_step_min = 0.001 + self.time_step_max = 0.1 + + self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.ssm_state_size + self.conv1d = nn.Conv1d( + in_channels=self.conv_dim, + out_channels=self.conv_dim, + bias=config.mamba_conv_bias, + kernel_size=self.conv_kernel_size, + groups=self.conv_dim, + padding=self.conv_kernel_size - 1, + ) + + # projection of the input hidden states + projection_size = self.intermediate_size + self.conv_dim + self.num_heads + self.in_proj = nn.Linear( + self.hidden_size, + projection_size, + bias=self.use_bias, + ) + # selective projection used to make dt, B and C input dependant + + # time step projection (discretization) + # instantiate once and copy inv_dt in init_weights of PretrainedModel + self.dt_bias = nn.Parameter(torch.ones(self.num_heads)) + + # S4D real initialization. These are not discretized! + # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded + A = torch.arange(1, self.num_heads + 1) + self.A_log = nn.Parameter(torch.log(A)) + self.A_log._no_weight_decay = True + self.norm = BambaRMSNormGated(self.intermediate_size, eps=self.layer_norm_epsilon) + self.D = nn.Parameter(torch.ones(self.num_heads)) + self.D._no_weight_decay = True + + self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=self.use_bias) + + if not is_fast_path_available: + logger.warning_once( + "The fast path is not available because on of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`" + " is None. Falling back to the naive implementation. To install follow https://github.com/state-spaces/mamba/#installation and" + " https://github.com/Dao-AILab/causal-conv1d" + ) + else: + logger.warning_once("The fast path for Bamba will be used when running the model on a GPU") + + def cuda_kernels_forward( + self, + hidden_states: torch.Tensor, + cache_params: Optional[HybridMambaAttentionDynamicCache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ): + # 1. Gated MLP's linear projection + hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask) + projected_states = self.in_proj(hidden_states) + + # Set up dimensions for reshapes later + batch_size, seq_len, _ = hidden_states.shape + groups_time_state_size = self.n_groups * self.ssm_state_size + + use_precomputed_states = ( + cache_params is not None + and cache_params.has_previous_state + and seq_len == 1 + and cache_params.conv_states[self.layer_idx].shape[0] + == cache_params.ssm_states[self.layer_idx].shape[0] + == batch_size + and cache_position is not None + and cache_position[0] > 0 + ) + + # getting projected states from cache if it exists + if use_precomputed_states: + gate, hidden_states_B_C, dt = projected_states.squeeze(1).split( + [self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 + ) + + # 2. Convolution sequence transformation + hidden_states_B_C = causal_conv1d_update( + hidden_states_B_C, + cache_params.conv_states[self.layer_idx], + self.conv1d.weight.squeeze(1), + self.conv1d.bias, + self.activation, + ) + + hidden_states, B, C = torch.split( + hidden_states_B_C, + [self.intermediate_size, groups_time_state_size, groups_time_state_size], + dim=-1, + ) + + # 3. SSM transformation + A = -torch.exp(self.A_log.float()) # (nheads,) + A = A[:, None, ...][:, :, None].expand(-1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32) + dt = dt[:, :, None].expand(-1, -1, self.head_dim) + dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim) + D = self.D[:, None, ...].expand(-1, self.head_dim) + B = B.view(batch_size, self.n_groups, B.shape[1] // self.n_groups) + C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups) + hidden_states_reshaped = hidden_states.view(batch_size, self.num_heads, self.head_dim) + hidden_states = selective_state_update( + cache_params.ssm_states[self.layer_idx], + hidden_states_reshaped, + dt, + A, + B, + C, + D, + z=None, + dt_bias=dt_bias, + dt_softplus=True, + ) + hidden_states = hidden_states.view(batch_size, self.num_heads * self.head_dim) + hidden_states = self.norm(hidden_states, gate) + + # 4. Final linear projection + out = self.out_proj(hidden_states)[:, None, ...] + # Fused calculations or step by step if no initialized cache is found + else: + A = -torch.exp(self.A_log.float()) # (num_heads) or (intermediate_size, state_size) + dt_limit_kwargs = {} if self.time_step_limit == (0.0, float("inf")) else {"dt_limit": self.time_step_limit} + + # 2-4. Fused kernel for conv1d, SSM, and the final projection + if self.training and cache_params is None: + out = mamba_split_conv1d_scan_combined( + projected_states, + self.conv1d.weight.squeeze(1), + self.conv1d.bias, + self.dt_bias, + A, + D=self.D, + chunk_size=self.chunk_size, + seq_idx=None, # was seq_idx + activation=self.activation, + rmsnorm_weight=self.norm.weight, + rmsnorm_eps=self.norm.variance_epsilon, + outproj_weight=self.out_proj.weight, + outproj_bias=self.out_proj.bias, + headdim=self.head_dim, + ngroups=self.n_groups, + norm_before_gate=False, + return_final_states=False, + **dt_limit_kwargs, + ) + + else: + gate, hidden_states_B_C, dt = projected_states.split( + [self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 + ) + + # 2. Convolution sequence transformation + # Init cache + if cache_params is not None: + # storing the states + # If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv + # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise. + hidden_states_B_C_transposed = hidden_states_B_C.transpose(1, 2) + conv_states = nn.functional.pad( + hidden_states_B_C_transposed, + (self.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0), + ) + cache_params.conv_states[self.layer_idx].copy_(conv_states) + + if self.activation not in ["silu", "swish"]: + hidden_states_B_C = self.act( + self.conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2) + ) + else: + hidden_states_B_C = causal_conv1d_fn( + x=hidden_states_B_C.transpose(1, 2), + weight=self.conv1d.weight.squeeze(1), + bias=self.conv1d.bias, + activation=self.activation, + ).transpose(1, 2) + + hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask) + hidden_states, B, C = torch.split( + hidden_states_B_C, + [self.intermediate_size, groups_time_state_size, groups_time_state_size], + dim=-1, + ) + + # 3. SSM transformation + scan_output, ssm_state = mamba_chunk_scan_combined( + hidden_states.view(batch_size, seq_len, -1, self.head_dim), + dt, + A, + B.view(batch_size, seq_len, self.n_groups, -1), + C.view(batch_size, seq_len, self.n_groups, -1), + chunk_size=self.chunk_size, + D=self.D, + z=None, + seq_idx=None, + return_final_states=True, + dt_bias=self.dt_bias, + dt_softplus=True, + **dt_limit_kwargs, + ) + + # Init cache + if ssm_state is not None and cache_params is not None: + cache_params.ssm_states[self.layer_idx].copy_(ssm_state) + + scan_output = scan_output.view(batch_size, seq_len, -1) + # Multiply "gate" branch and apply extra normalization layer + scan_output = self.norm(scan_output, gate) + + # 4. Final linear projection + out = self.out_proj(scan_output) + return out + + # fmt: off + def torch_forward( + self, + input_states, + cache_params: Optional[HybridMambaAttentionDynamicCache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ): + batch_size, seq_len, _ = input_states.shape + dtype = input_states.dtype + + # 1. Gated MLP's linear projection + input_states = apply_mask_to_padding_states(input_states, attention_mask) + projected_states = self.in_proj(input_states) + gate, hidden_states_B_C, dt = projected_states.split( + [self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 + ) + + use_precomputed_states = ( + cache_params is not None + and cache_params.has_previous_state + and seq_len == 1 + and cache_params.conv_states[self.layer_idx].shape[0] + == cache_params.ssm_states[self.layer_idx].shape[0] + == batch_size + and cache_position is not None + and cache_position[0] > 0 + ) + + # 2. Convolution sequence transformation + if use_precomputed_states: + cache_params.conv_states[self.layer_idx] = cache_params.conv_states[self.layer_idx].roll(shifts=-1, dims=-1) + cache_params.conv_states[self.layer_idx][:, :, -1] = hidden_states_B_C[:, 0, :].to(cache_params.conv_states[self.layer_idx].device) + + # We need to guarantee that anything regarding the cache is on the same device + conv_states = cache_params.conv_states[self.layer_idx].to(device=self.conv1d.weight.device) + + hidden_states_B_C = torch.sum( + conv_states * self.conv1d.weight.squeeze(1), dim=-1 + ) + if self.use_conv_bias: + hidden_states_B_C = hidden_states_B_C + self.conv1d.bias + hidden_states_B_C = self.act(hidden_states_B_C) + else: + # Init cache + if cache_params is not None: + hidden_states_B_C_transposed = hidden_states_B_C.transpose(1, 2) + conv_states = nn.functional.pad( + hidden_states_B_C_transposed, (self.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0) + ) + cache_params.conv_states[self.layer_idx].copy_(conv_states) + + hidden_states_B_C = self.act(self.conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2)) + + hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask) + hidden_states, B, C = torch.split( + hidden_states_B_C, + [self.intermediate_size, self.n_groups * self.ssm_state_size, self.n_groups * self.ssm_state_size], + dim=-1 + ) + + # 3. SSM transformation + A = -torch.exp(self.A_log.float()) # [num_heads] + if use_precomputed_states: + # We need to guarantee that anything regarding the cache is on the same device + cache_device = cache_params.ssm_states[self.layer_idx].device + + # Note: there is no need to pad parameter matrices here, as there is just one new token + # for batched generation + dt = dt[:, 0, :][:, None, ...] + dt = dt.transpose(1, 2).expand(batch_size, dt.shape[-1], self.head_dim) + # [num_heads] -> [num_heads, head_dim] + dt_bias = self.dt_bias[..., None].expand(self.dt_bias.shape[0], self.head_dim) + + dt = torch.nn.functional.softplus(dt + dt_bias.to(dt.dtype)) + dt = torch.clamp(dt, self.time_step_limit[0], self.time_step_limit[1]) + A = A[..., None, None].expand(self.num_heads, self.head_dim, self.ssm_state_size).to(dtype=torch.float32) + # [bsz, num_heads, head_dim, state_size] + dA = (torch.exp(dt[..., None] * A)).to(device=cache_device) + + # Discretize B + # [bsz, n_groups * state_size] -> [bsz, n_groups, 1, state_size] -> + # -> [bsz, n_groups, group to head repetition factor, state_size] -> [bsz, num_heads, state_size] + B = B.reshape(batch_size, self.n_groups, -1)[..., None, :] + B = B.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, B.shape[-1]).contiguous() + B = B.reshape(batch_size, -1, B.shape[-1]) + # [bsz, num_heads, head_dim, state_size] + dB = dt[..., None] * B[..., None, :] + + # Discretize x into dB + # [bsz, intermediate_size] -> [bsz, num_heads, head_dim] + hidden_states = hidden_states.reshape(batch_size, -1, self.head_dim) + dBx = (dB * hidden_states[..., None]).to(device=cache_device) + + # State calculation + cache_params.ssm_states[self.layer_idx].copy_( + cache_params.ssm_states[self.layer_idx] * dA + dBx + ) + + # Subsequent output + # [bsz, n_groups * state_size] -> [bsz, num_heads, state_size] + C = C.reshape(batch_size, self.n_groups, -1)[..., None, :] + C = C.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, C.shape[-1]).contiguous() + C = C.reshape(batch_size, -1, C.shape[-1]) + # [bsz, num_heads, head_dim] + + ssm_states = cache_params.ssm_states[self.layer_idx].to(device=C.device, dtype=C.dtype) # Shape: [b, h, d, n] + # Reshape ssm_states to merge the first two dimensions + ssm_states_reshaped = ssm_states.view(batch_size * self.num_heads, self.head_dim, self.ssm_state_size) # Shape: [b*h, d, n] + C_reshaped = C.view(batch_size * self.num_heads, self.ssm_state_size, 1) # Shape: [b*h, n, 1] + y = torch.bmm(ssm_states_reshaped, C_reshaped) + y = y.view(batch_size, self.num_heads, self.head_dim) + + # D skip connection + # [num_heads] -> [num_heads, head_dim] + D = self.D[..., None].expand(self.D.shape[0], self.head_dim) + y = (y + hidden_states * D).to(y.dtype) + + # [bsz, num_heads, head_dim] -> [bsz, 1, intermediate_size] + y = y.reshape(batch_size, -1)[:, None, ...] + else: + # begin ssd naive implementation without einsums + dt = nn.functional.softplus(dt + self.dt_bias) + dt = torch.clamp(dt, self.time_step_limit[0], self.time_step_limit[1]) + hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).float() + B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).float() + C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).float() + B = B.repeat(1, 1, self.num_heads // self.n_groups, 1) + C = C.repeat(1, 1, self.num_heads // self.n_groups, 1) + pad_size = (self.chunk_size - seq_len % self.chunk_size) % self.chunk_size + + D_residual = self.D[..., None] * pad_tensor_by_size(hidden_states, pad_size) + + # Discretize x and A + hidden_states = hidden_states * dt[..., None] + A = A.to(hidden_states.dtype) * dt + + # Rearrange into blocks/chunks + hidden_states, A, B, C = [reshape_into_chunks(t, pad_size, self.chunk_size) for t in (hidden_states, A, B, C)] + + # [bsz, -1, chunk_size, num_heads] -> [bsz, num_heads, -1, chunk_size] + A = A.permute(0, 3, 1, 2) + A_cumsum = torch.cumsum(A, dim=-1) + + # 1. Compute the output for each intra-chunk (diagonal blocks) + # This is the analog of a causal mask + L = torch.exp(segment_sum(A)) + + # Contraction of C and B to get G (attention-weights like) + G_intermediate = C[:, :, :, None, :, :] * B[:, :, None, :, :, :] # shape: (b, c, l, s, h, n) + G = G_intermediate.sum(dim=-1) # shape: (b, c, l, s, h) + + # Compute M, equivalent to applying attention mask to weights + M_intermediate = G[..., None] * L.permute(0, 2, 3, 4, 1)[..., None] + M = M_intermediate.sum(dim=-1) + + # Compute Y_diag (apply to values) + Y_diag = (M[..., None] * hidden_states[:, :, None]).sum(dim=3) + + # 2. Compute the state for each intra-chunk + # (right term of low-rank factorization of off-diagonal blocks; B terms) + decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum)) + B_decay = B * decay_states.permute(0, -2, -1, 1)[..., None] + states = (B_decay[..., None, :] * hidden_states[..., None]).sum(dim=2) + + # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries + # (middle term of factorization of off-diag blocks; A terms) + if use_precomputed_states: + previous_states = cache_params.ssm_states[self.layer_idx][:, None, ...].to(device=states.device) + else: + previous_states = torch.zeros_like(states[:, :1]) + states = torch.cat([previous_states, states], dim=1) + decay_chunk = torch.exp(segment_sum(nn.functional.pad(A_cumsum[:, :, :, -1], (1, 0)))) + decay_chunk = decay_chunk.transpose(1, 3) + new_states = (decay_chunk[..., None, None] * states[:, :, None, ...]).sum(dim=1) + states, ssm_state = new_states[:, :-1], new_states[:, -1] + + # 4. Compute state -> output conversion per chunk + # (left term of low-rank factorization of off-diagonal blocks; C terms) + state_decay_out = torch.exp(A_cumsum) + C_times_states = (C[..., None, :] * states[:, :, None, ...]) + state_decay_out_permuted = state_decay_out.permute(0, 2, 3, 1) + Y_off = (C_times_states.sum(-1) * state_decay_out_permuted[..., None]) + + # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks) + y = Y_diag + Y_off + # [bsz, -1, self.chunk_size, num_heads, head_dim] -> [bsz, (padded) seq_len, num_heads, head_dim] + y = y.reshape(batch_size, -1, self.num_heads, self.head_dim) + + y = y + D_residual + # Cutting off padded chunks + if pad_size > 0: + y = y[:, :seq_len, :, :] + y = y.reshape(batch_size, seq_len, -1) + + # Init cache + if ssm_state is not None and cache_params is not None: + cache_params.ssm_states[self.layer_idx].copy_(ssm_state) + + scan_output = self.norm(y, gate) + + # end ssd naive + + # 4. Final linear projection + contextualized_states = self.out_proj(scan_output.to(dtype)) # [batch, seq_len, hidden_size] + return contextualized_states + # fmt: on + + def forward( + self, + hidden_states, + cache_params: Optional[HybridMambaAttentionDynamicCache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ): + if is_fast_path_available and "cuda" in self.in_proj.weight.device.type: + return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask) + dtype = hidden_states.dtype + if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: + # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) + + return self.torch_forward(hidden_states, cache_params, cache_position, attention_mask) + + +class BambaMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +class BambaRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + BambaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class BambaDecoderLayer(nn.Module): + def __init__(self, config: BambaConfig, layer_idx: int, layer_type: str = "mamba"): + super().__init__() + + num_experts = 1 + ffn_layer_class = BambaMLP if num_experts == 1 else None + self.feed_forward = ffn_layer_class(config) + self.input_layernorm = BambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.pre_ff_layernorm = BambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.layer_type = layer_type + if layer_type == "mamba": + self.mamba = BambaMixer(config=config, layer_idx=layer_idx) + elif layer_type == "attention": + self.self_attn = BambaAttention(config, layer_idx) + else: + raise ValueError("Invalid layer_type") + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + past_key_value (`HybridMambaAttentionDynamicCache`, *optional*): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # this is a hybrid decoder layer + if self.layer_type == "mamba": + hidden_states = self.mamba( + hidden_states=hidden_states, + cache_params=past_key_value, + cache_position=cache_position, + attention_mask=attention_mask, + ) + self_attn_weights = None + elif self.layer_type == "attention": + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + + # residual connection after attention + hidden_states = residual + hidden_states + + # feed-forward + residual = hidden_states + hidden_states = self.pre_ff_layernorm(hidden_states) + hidden_states = self.feed_forward(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + return outputs + + +BAMBA_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`BambaConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare BambaModel outputting raw hidden-states without any specific head on top.", + BAMBA_START_DOCSTRING, +) +class BambaPreTrainedModel(PreTrainedModel): + config_class = BambaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["BambaDecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True # Note: only supports HybridMambaAttentionDynamicCache + _is_stateful = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, (nn.Linear, nn.Conv1d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +BAMBA_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`HybridMambaAttentionDynamicCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + A HybridMambaAttentionDynamicCache object containing pre-computed hidden-states (keys and values in the + self-attention blocks and convolution and ssm states in the mamba blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + Key and value cache tensors have shape `(batch_size, num_heads, seq_len, head_dim)`. + Convolution and ssm states tensors have shape `(batch_size, d_inner, d_conv)` and + `(batch_size, d_inner, d_state)` respectively. + See the `HybridMambaAttentionDynamicCache` class for more details. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + output_router_logits (`bool`, *optional*): + Whether or not to return the logits of all the routers. They are useful for computing the router loss, and + should not be returned during inference. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + "The bare Bamba Model outputting raw hidden-states without any specific head on top.", + BAMBA_START_DOCSTRING, +) +# Adapted from transformers.models.jamba.modeling_jamba.JambaModel +class BambaModel(BambaPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`BambaDecoderLayer`] + + Args: + config: BambaConfig + """ + + def __init__(self, config: BambaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + decoder_layers = [] + for i in range(config.num_hidden_layers): + decoder_layers.append(BambaDecoderLayer(config, layer_idx=i, layer_type=config.layers_block_type[i])) + self.layers = nn.ModuleList(decoder_layers) + + self._attn_implementation = config._attn_implementation + self.final_layernorm = BambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = BambaRotaryEmbedding(config=config) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(BAMBA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[HybridMambaAttentionDynamicCache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + hidden_states = inputs_embeds + + if use_cache and past_key_values is None: + logger.warning_once( + "Bamba requires an initialized `HybridMambaAttentionDynamicCache` to return a cache. None was " + "provided, so no cache will be returned." + ) + + if cache_position is None: + cache_position = torch.arange(hidden_states.shape[1], device=hidden_states.device) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + mamba_mask = self._update_mamba_mask(attention_mask, cache_position) + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers: + # Depending on the layer type we opt for 2D base attention mask (Mamba) or 4D causal mask (Attention) + layer_mask = mamba_mask if decoder_layer.layer_type == "mamba" else causal_mask + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + layer_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=layer_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + if layer_outputs[1] is not None: + # append attentions only of attention layers. Mamba layers return `None` as the attention weights + all_self_attns += (layer_outputs[1],) + + hidden_states = self.final_layernorm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if past_key_values and not past_key_values.has_previous_state: + past_key_values.has_previous_state = True + + next_cache = None if not use_cache else past_key_values + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: HybridMambaAttentionDynamicCache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + sequence_length = input_tensor.shape[1] + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_attention_mask = (attention_mask[:, None, None, :] == attention_mask[:, None, :, None])[ + :, :, -sequence_length:, : + ].to(dtype) + padding_mask = causal_mask[:, :, :, :mask_length] + padding_attention_mask + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + def _update_mamba_mask(self, attention_mask, cache_position): + """ + No need for zeroing states when + 1. Cached forward + 2. Attending to all inputs + """ + mamba_mask = attention_mask + if cache_position[0] > 0 or (attention_mask is not None and torch.all(attention_mask == 1)): + mamba_mask = None + return mamba_mask + + +class BambaForCausalLM(BambaPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + + def __init__(self, config): + super().__init__(config) + self.model = BambaModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(BAMBA_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[HybridMambaAttentionDynamicCache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, + **kwargs, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + num_logits_to_keep (`int` or `None`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `None`, calculate logits for all + `input_ids`. Only last token logits are needed for generation, and calculating them only for that token + can save memory, which becomes pretty significant for long sequences. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, BambaForCausalLM + + >>> model = BambaForCausalLM.from_pretrained("...") + >>> tokenizer = AutoTokenizer.from_pretrained("...") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + **kwargs, + ): + # Overwitten -- has a unique cache type, `HybridMambaAttentionDynamicCache` + + empty_past_kv = past_key_values is None + + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here + if not empty_past_kv: + if inputs_embeds is not None: # Exception 1 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + else: + past_key_values = HybridMambaAttentionDynamicCache( + self.config, input_ids.shape[0], self.dtype, device=self.device + ) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if not empty_past_kv: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and empty_past_kv: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + "num_logits_to_keep": self.config.num_logits_to_keep, + "cache_position": cache_position, + } + ) + return model_inputs + + +__all__ = ["BambaModel", "BambaForCausalLM", "BambaPreTrainedModel"] diff --git a/src/transformers/models/bamba/modular_bamba.py b/src/transformers/models/bamba/modular_bamba.py new file mode 100644 index 00000000000000..7fb35f48fb3b76 --- /dev/null +++ b/src/transformers/models/bamba/modular_bamba.py @@ -0,0 +1,1303 @@ +# coding=utf-8 +# Copyright 2024 IBM and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Bamba model.""" + +from typing import Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn + +import transformers.models.jamba.modeling_jamba as modeling_jamba +from transformers.activations import ACT2FN +from transformers.models.jamba.modeling_jamba import JambaAttentionDecoderLayer +from transformers.models.llama.modeling_llama import ( + LlamaAttention, + LlamaForCausalLM, + LlamaMLP, + LlamaRMSNorm, + LlamaRotaryEmbedding, + rotate_half, +) +from transformers.models.mamba2.modeling_mamba2 import ( + MambaRMSNormGated, + pad_tensor_by_size, + reshape_into_chunks, + segment_sum, +) + +from ...modeling_attn_mask_utils import ( + AttentionMaskConverter, +) +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_utils import PreTrainedModel +from ...utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from ...utils.import_utils import ( + is_causal_conv1d_available, + is_flash_attn_2_available, + is_mamba_2_ssm_available, +) +from .configuration_bamba import BambaConfig + + +if is_flash_attn_2_available(): + pass + +if is_mamba_2_ssm_available(): + from mamba_ssm.ops.triton.selective_state_update import selective_state_update + from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined +else: + selective_state_update = None + +if is_causal_conv1d_available(): + from causal_conv1d import causal_conv1d_fn, causal_conv1d_update +else: + causal_conv1d_update, causal_conv1d_fn = None, None + +is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update)) + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "BambaConfig" + + +# Adapted from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache for the v2 mixer +class HybridMambaAttentionDynamicCache(modeling_jamba.HybridMambaAttentionDynamicCache): + """ + A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache + (which has a constant shape regardless of seq_len). + + This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states` + and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor + For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`, + while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors). + For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors), + while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`, + and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`. + """ + + def __init__(self, config: BambaConfig, batch_size, dtype=torch.float16, device=None): + super().__init__(config, batch_size, dtype, device) + self.layers_block_type = config.layers_block_type + self.has_previous_state = False # only used by mamba + conv_kernel_size = config.mamba_d_conv + ssm_state_size = config.mamba_d_state + + self.conv_states = [] + self.ssm_states = [] + self.transformer_layers = [] + for i in range(config.num_hidden_layers): + if self.layers_block_type[i] == "mamba": + self.conv_states += [ + torch.zeros( + batch_size, + (config.mamba_expand * config.hidden_size + 2 * config.mamba_n_groups * ssm_state_size), + conv_kernel_size, + device=device, + dtype=dtype, + ) + ] + self.ssm_states += [ + torch.zeros( + batch_size, + config.mamba_n_heads, + config.mamba_d_head, + ssm_state_size, + device=device, + dtype=dtype, + ) + ] + else: + self.conv_states += [torch.tensor([[]] * batch_size, device=device)] + self.ssm_states += [torch.tensor([[]] * batch_size, device=device)] + self.transformer_layers.append(i) + + self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] + self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] + + +class BambaRotaryEmbedding(LlamaRotaryEmbedding): + pass + + +# Adapted from transformers.models.glm.modular_glm.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Removes the interleaving of cos and sin from GLM + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + + # Keep half or full tensor for later concatenation + rotary_dim = cos.shape[-1] + q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] + k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] + + # Apply rotary embeddings on the first half or full tensor + q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin) + k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin) + + # Concatenate back to full shape + q_embed = torch.cat([q_embed, q_pass], dim=-1) + k_embed = torch.cat([k_embed, k_pass], dim=-1) + return q_embed, k_embed + + +class BambaAttention(LlamaAttention): + pass + + +class BambaRMSNormGated(MambaRMSNormGated): + pass + + +def apply_mask_to_padding_states(hidden_states, attention_mask): + """ + Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66 + """ + if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: + dtype = hidden_states.dtype + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) + + return hidden_states + + +# Adapted from transformers.models.mamba2.modeling_mamba2.Mamba2Mixer +class BambaMixer(nn.Module): + """ + Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`. + A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective) + ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4, + and is why Mamba is called **selective** state spaces) + + The are a few differences between this and Mamba2Mixer: + - The variable use_precomputed_states is slightly different due to the HybridCache structure + - There's a few non-obvious bugs fixed with batching in the slow path that exist in main + - Some extra variables that our layer doesn't need have been removed + - We ported most of the refactors in https://github.com/huggingface/transformers/pull/35154, which is (as of Dec 18, 2024) unmerged + """ + + def __init__(self, config: BambaConfig, layer_idx: int): + super().__init__() + self.num_heads = config.mamba_n_heads + self.hidden_size = config.hidden_size + self.ssm_state_size = config.mamba_d_state + self.conv_kernel_size = config.mamba_d_conv + self.intermediate_size = int(config.mamba_expand * self.hidden_size) + self.layer_idx = layer_idx + self.use_conv_bias = config.mamba_conv_bias + self.activation = config.hidden_act + self.act = ACT2FN[config.hidden_act] + self.use_bias = config.mamba_proj_bias + + self.layer_norm_epsilon = config.rms_norm_eps + + self.n_groups = config.mamba_n_groups + self.head_dim = config.mamba_d_head + self.chunk_size = config.mamba_chunk_size + + # FIXME: + self.time_step_limit = (0.0, float("inf")) + self.time_step_min = 0.001 + self.time_step_max = 0.1 + + self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.ssm_state_size + self.conv1d = nn.Conv1d( + in_channels=self.conv_dim, + out_channels=self.conv_dim, + bias=config.mamba_conv_bias, + kernel_size=self.conv_kernel_size, + groups=self.conv_dim, + padding=self.conv_kernel_size - 1, + ) + + # projection of the input hidden states + projection_size = self.intermediate_size + self.conv_dim + self.num_heads + self.in_proj = nn.Linear( + self.hidden_size, + projection_size, + bias=self.use_bias, + ) + # selective projection used to make dt, B and C input dependant + + # time step projection (discretization) + # instantiate once and copy inv_dt in init_weights of PretrainedModel + self.dt_bias = nn.Parameter(torch.ones(self.num_heads)) + + # S4D real initialization. These are not discretized! + # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded + A = torch.arange(1, self.num_heads + 1) + self.A_log = nn.Parameter(torch.log(A)) + self.A_log._no_weight_decay = True + self.norm = BambaRMSNormGated(self.intermediate_size, eps=self.layer_norm_epsilon) + self.D = nn.Parameter(torch.ones(self.num_heads)) + self.D._no_weight_decay = True + + self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=self.use_bias) + + if not is_fast_path_available: + logger.warning_once( + "The fast path is not available because on of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`" + " is None. Falling back to the naive implementation. To install follow https://github.com/state-spaces/mamba/#installation and" + " https://github.com/Dao-AILab/causal-conv1d" + ) + else: + logger.warning_once("The fast path for Bamba will be used when running the model on a GPU") + + def cuda_kernels_forward( + self, + hidden_states: torch.Tensor, + cache_params: Optional[HybridMambaAttentionDynamicCache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ): + # 1. Gated MLP's linear projection + hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask) + projected_states = self.in_proj(hidden_states) + + # Set up dimensions for reshapes later + batch_size, seq_len, _ = hidden_states.shape + groups_time_state_size = self.n_groups * self.ssm_state_size + + use_precomputed_states = ( + cache_params is not None + and cache_params.has_previous_state + and seq_len == 1 + and cache_params.conv_states[self.layer_idx].shape[0] + == cache_params.ssm_states[self.layer_idx].shape[0] + == batch_size + and cache_position is not None + and cache_position[0] > 0 + ) + + # getting projected states from cache if it exists + if use_precomputed_states: + gate, hidden_states_B_C, dt = projected_states.squeeze(1).split( + [self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 + ) + + # 2. Convolution sequence transformation + hidden_states_B_C = causal_conv1d_update( + hidden_states_B_C, + cache_params.conv_states[self.layer_idx], + self.conv1d.weight.squeeze(1), + self.conv1d.bias, + self.activation, + ) + + hidden_states, B, C = torch.split( + hidden_states_B_C, + [self.intermediate_size, groups_time_state_size, groups_time_state_size], + dim=-1, + ) + + # 3. SSM transformation + A = -torch.exp(self.A_log.float()) # (nheads,) + A = A[:, None, ...][:, :, None].expand(-1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32) + dt = dt[:, :, None].expand(-1, -1, self.head_dim) + dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim) + D = self.D[:, None, ...].expand(-1, self.head_dim) + B = B.view(batch_size, self.n_groups, B.shape[1] // self.n_groups) + C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups) + hidden_states_reshaped = hidden_states.view(batch_size, self.num_heads, self.head_dim) + hidden_states = selective_state_update( + cache_params.ssm_states[self.layer_idx], + hidden_states_reshaped, + dt, + A, + B, + C, + D, + z=None, + dt_bias=dt_bias, + dt_softplus=True, + ) + hidden_states = hidden_states.view(batch_size, self.num_heads * self.head_dim) + hidden_states = self.norm(hidden_states, gate) + + # 4. Final linear projection + out = self.out_proj(hidden_states)[:, None, ...] + # Fused calculations or step by step if no initialized cache is found + else: + A = -torch.exp(self.A_log.float()) # (num_heads) or (intermediate_size, state_size) + dt_limit_kwargs = {} if self.time_step_limit == (0.0, float("inf")) else {"dt_limit": self.time_step_limit} + + # 2-4. Fused kernel for conv1d, SSM, and the final projection + if self.training and cache_params is None: + out = mamba_split_conv1d_scan_combined( + projected_states, + self.conv1d.weight.squeeze(1), + self.conv1d.bias, + self.dt_bias, + A, + D=self.D, + chunk_size=self.chunk_size, + seq_idx=None, # was seq_idx + activation=self.activation, + rmsnorm_weight=self.norm.weight, + rmsnorm_eps=self.norm.variance_epsilon, + outproj_weight=self.out_proj.weight, + outproj_bias=self.out_proj.bias, + headdim=self.head_dim, + ngroups=self.n_groups, + norm_before_gate=False, + return_final_states=False, + **dt_limit_kwargs, + ) + + else: + gate, hidden_states_B_C, dt = projected_states.split( + [self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 + ) + + # 2. Convolution sequence transformation + # Init cache + if cache_params is not None: + # storing the states + # If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv + # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise. + hidden_states_B_C_transposed = hidden_states_B_C.transpose(1, 2) + conv_states = nn.functional.pad( + hidden_states_B_C_transposed, + (self.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0), + ) + cache_params.conv_states[self.layer_idx].copy_(conv_states) + + if self.activation not in ["silu", "swish"]: + hidden_states_B_C = self.act( + self.conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2) + ) + else: + hidden_states_B_C = causal_conv1d_fn( + x=hidden_states_B_C.transpose(1, 2), + weight=self.conv1d.weight.squeeze(1), + bias=self.conv1d.bias, + activation=self.activation, + ).transpose(1, 2) + + hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask) + hidden_states, B, C = torch.split( + hidden_states_B_C, + [self.intermediate_size, groups_time_state_size, groups_time_state_size], + dim=-1, + ) + + # 3. SSM transformation + scan_output, ssm_state = mamba_chunk_scan_combined( + hidden_states.view(batch_size, seq_len, -1, self.head_dim), + dt, + A, + B.view(batch_size, seq_len, self.n_groups, -1), + C.view(batch_size, seq_len, self.n_groups, -1), + chunk_size=self.chunk_size, + D=self.D, + z=None, + seq_idx=None, + return_final_states=True, + dt_bias=self.dt_bias, + dt_softplus=True, + **dt_limit_kwargs, + ) + + # Init cache + if ssm_state is not None and cache_params is not None: + cache_params.ssm_states[self.layer_idx].copy_(ssm_state) + + scan_output = scan_output.view(batch_size, seq_len, -1) + # Multiply "gate" branch and apply extra normalization layer + scan_output = self.norm(scan_output, gate) + + # 4. Final linear projection + out = self.out_proj(scan_output) + return out + + # fmt: off + def torch_forward( + self, + input_states, + cache_params: Optional[HybridMambaAttentionDynamicCache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ): + batch_size, seq_len, _ = input_states.shape + dtype = input_states.dtype + + # 1. Gated MLP's linear projection + input_states = apply_mask_to_padding_states(input_states, attention_mask) + projected_states = self.in_proj(input_states) + gate, hidden_states_B_C, dt = projected_states.split( + [self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 + ) + + use_precomputed_states = ( + cache_params is not None + and cache_params.has_previous_state + and seq_len == 1 + and cache_params.conv_states[self.layer_idx].shape[0] + == cache_params.ssm_states[self.layer_idx].shape[0] + == batch_size + and cache_position is not None + and cache_position[0] > 0 + ) + + # 2. Convolution sequence transformation + if use_precomputed_states: + cache_params.conv_states[self.layer_idx] = cache_params.conv_states[self.layer_idx].roll(shifts=-1, dims=-1) + cache_params.conv_states[self.layer_idx][:, :, -1] = hidden_states_B_C[:, 0, :].to(cache_params.conv_states[self.layer_idx].device) + + # We need to guarantee that anything regarding the cache is on the same device + conv_states = cache_params.conv_states[self.layer_idx].to(device=self.conv1d.weight.device) + + hidden_states_B_C = torch.sum( + conv_states * self.conv1d.weight.squeeze(1), dim=-1 + ) + if self.use_conv_bias: + hidden_states_B_C = hidden_states_B_C + self.conv1d.bias + hidden_states_B_C = self.act(hidden_states_B_C) + else: + # Init cache + if cache_params is not None: + hidden_states_B_C_transposed = hidden_states_B_C.transpose(1, 2) + conv_states = nn.functional.pad( + hidden_states_B_C_transposed, (self.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0) + ) + cache_params.conv_states[self.layer_idx].copy_(conv_states) + + hidden_states_B_C = self.act(self.conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2)) + + hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask) + hidden_states, B, C = torch.split( + hidden_states_B_C, + [self.intermediate_size, self.n_groups * self.ssm_state_size, self.n_groups * self.ssm_state_size], + dim=-1 + ) + + # 3. SSM transformation + A = -torch.exp(self.A_log.float()) # [num_heads] + if use_precomputed_states: + # We need to guarantee that anything regarding the cache is on the same device + cache_device = cache_params.ssm_states[self.layer_idx].device + + # Note: there is no need to pad parameter matrices here, as there is just one new token + # for batched generation + dt = dt[:, 0, :][:, None, ...] + dt = dt.transpose(1, 2).expand(batch_size, dt.shape[-1], self.head_dim) + # [num_heads] -> [num_heads, head_dim] + dt_bias = self.dt_bias[..., None].expand(self.dt_bias.shape[0], self.head_dim) + + dt = torch.nn.functional.softplus(dt + dt_bias.to(dt.dtype)) + dt = torch.clamp(dt, self.time_step_limit[0], self.time_step_limit[1]) + A = A[..., None, None].expand(self.num_heads, self.head_dim, self.ssm_state_size).to(dtype=torch.float32) + # [bsz, num_heads, head_dim, state_size] + dA = (torch.exp(dt[..., None] * A)).to(device=cache_device) + + # Discretize B + # [bsz, n_groups * state_size] -> [bsz, n_groups, 1, state_size] -> + # -> [bsz, n_groups, group to head repetition factor, state_size] -> [bsz, num_heads, state_size] + B = B.reshape(batch_size, self.n_groups, -1)[..., None, :] + B = B.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, B.shape[-1]).contiguous() + B = B.reshape(batch_size, -1, B.shape[-1]) + # [bsz, num_heads, head_dim, state_size] + dB = dt[..., None] * B[..., None, :] + + # Discretize x into dB + # [bsz, intermediate_size] -> [bsz, num_heads, head_dim] + hidden_states = hidden_states.reshape(batch_size, -1, self.head_dim) + dBx = (dB * hidden_states[..., None]).to(device=cache_device) + + # State calculation + cache_params.ssm_states[self.layer_idx].copy_( + cache_params.ssm_states[self.layer_idx] * dA + dBx + ) + + # Subsequent output + # [bsz, n_groups * state_size] -> [bsz, num_heads, state_size] + C = C.reshape(batch_size, self.n_groups, -1)[..., None, :] + C = C.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, C.shape[-1]).contiguous() + C = C.reshape(batch_size, -1, C.shape[-1]) + # [bsz, num_heads, head_dim] + + ssm_states = cache_params.ssm_states[self.layer_idx].to(device=C.device, dtype=C.dtype) # Shape: [b, h, d, n] + # Reshape ssm_states to merge the first two dimensions + ssm_states_reshaped = ssm_states.view(batch_size * self.num_heads, self.head_dim, self.ssm_state_size) # Shape: [b*h, d, n] + C_reshaped = C.view(batch_size * self.num_heads, self.ssm_state_size, 1) # Shape: [b*h, n, 1] + y = torch.bmm(ssm_states_reshaped, C_reshaped) + y = y.view(batch_size, self.num_heads, self.head_dim) + + # D skip connection + # [num_heads] -> [num_heads, head_dim] + D = self.D[..., None].expand(self.D.shape[0], self.head_dim) + y = (y + hidden_states * D).to(y.dtype) + + # [bsz, num_heads, head_dim] -> [bsz, 1, intermediate_size] + y = y.reshape(batch_size, -1)[:, None, ...] + else: + # begin ssd naive implementation without einsums + dt = nn.functional.softplus(dt + self.dt_bias) + dt = torch.clamp(dt, self.time_step_limit[0], self.time_step_limit[1]) + hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).float() + B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).float() + C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).float() + B = B.repeat(1, 1, self.num_heads // self.n_groups, 1) + C = C.repeat(1, 1, self.num_heads // self.n_groups, 1) + pad_size = (self.chunk_size - seq_len % self.chunk_size) % self.chunk_size + + D_residual = self.D[..., None] * pad_tensor_by_size(hidden_states, pad_size) + + # Discretize x and A + hidden_states = hidden_states * dt[..., None] + A = A.to(hidden_states.dtype) * dt + + # Rearrange into blocks/chunks + hidden_states, A, B, C = [reshape_into_chunks(t, pad_size, self.chunk_size) for t in (hidden_states, A, B, C)] + + # [bsz, -1, chunk_size, num_heads] -> [bsz, num_heads, -1, chunk_size] + A = A.permute(0, 3, 1, 2) + A_cumsum = torch.cumsum(A, dim=-1) + + # 1. Compute the output for each intra-chunk (diagonal blocks) + # This is the analog of a causal mask + L = torch.exp(segment_sum(A)) + + # Contraction of C and B to get G (attention-weights like) + G_intermediate = C[:, :, :, None, :, :] * B[:, :, None, :, :, :] # shape: (b, c, l, s, h, n) + G = G_intermediate.sum(dim=-1) # shape: (b, c, l, s, h) + + # Compute M, equivalent to applying attention mask to weights + M_intermediate = G[..., None] * L.permute(0, 2, 3, 4, 1)[..., None] + M = M_intermediate.sum(dim=-1) + + # Compute Y_diag (apply to values) + Y_diag = (M[..., None] * hidden_states[:, :, None]).sum(dim=3) + + # 2. Compute the state for each intra-chunk + # (right term of low-rank factorization of off-diagonal blocks; B terms) + decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum)) + B_decay = B * decay_states.permute(0, -2, -1, 1)[..., None] + states = (B_decay[..., None, :] * hidden_states[..., None]).sum(dim=2) + + # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries + # (middle term of factorization of off-diag blocks; A terms) + if use_precomputed_states: + previous_states = cache_params.ssm_states[self.layer_idx][:, None, ...].to(device=states.device) + else: + previous_states = torch.zeros_like(states[:, :1]) + states = torch.cat([previous_states, states], dim=1) + decay_chunk = torch.exp(segment_sum(nn.functional.pad(A_cumsum[:, :, :, -1], (1, 0)))) + decay_chunk = decay_chunk.transpose(1, 3) + new_states = (decay_chunk[..., None, None] * states[:, :, None, ...]).sum(dim=1) + states, ssm_state = new_states[:, :-1], new_states[:, -1] + + # 4. Compute state -> output conversion per chunk + # (left term of low-rank factorization of off-diagonal blocks; C terms) + state_decay_out = torch.exp(A_cumsum) + C_times_states = (C[..., None, :] * states[:, :, None, ...]) + state_decay_out_permuted = state_decay_out.permute(0, 2, 3, 1) + Y_off = (C_times_states.sum(-1) * state_decay_out_permuted[..., None]) + + # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks) + y = Y_diag + Y_off + # [bsz, -1, self.chunk_size, num_heads, head_dim] -> [bsz, (padded) seq_len, num_heads, head_dim] + y = y.reshape(batch_size, -1, self.num_heads, self.head_dim) + + y = y + D_residual + # Cutting off padded chunks + if pad_size > 0: + y = y[:, :seq_len, :, :] + y = y.reshape(batch_size, seq_len, -1) + + # Init cache + if ssm_state is not None and cache_params is not None: + cache_params.ssm_states[self.layer_idx].copy_(ssm_state) + + scan_output = self.norm(y, gate) + + # end ssd naive + + # 4. Final linear projection + contextualized_states = self.out_proj(scan_output.to(dtype)) # [batch, seq_len, hidden_size] + return contextualized_states + # fmt: on + + def forward( + self, + hidden_states, + cache_params: Optional[HybridMambaAttentionDynamicCache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ): + if is_fast_path_available and "cuda" in self.in_proj.weight.device.type: + return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask) + dtype = hidden_states.dtype + if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: + # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) + + return self.torch_forward(hidden_states, cache_params, cache_position, attention_mask) + + +class BambaMLP(LlamaMLP): + pass + + +class BambaRMSNorm(LlamaRMSNorm): + pass + + +class BambaDecoderLayer(JambaAttentionDecoderLayer): + def __init__(self, config: BambaConfig, layer_idx: int, layer_type: str = "mamba"): + super().__init__() + + del self.self_attn + + num_experts = 1 + ffn_layer_class = BambaMLP if num_experts == 1 else None + self.feed_forward = ffn_layer_class(config) + + self.layer_type = layer_type + if layer_type == "mamba": + self.mamba = BambaMixer(config=config, layer_idx=layer_idx) + elif layer_type == "attention": + self.self_attn = BambaAttention(config, layer_idx) + else: + raise ValueError("Invalid layer_type") + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + past_key_value (`HybridMambaAttentionDynamicCache`, *optional*): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # this is a hybrid decoder layer + if self.layer_type == "mamba": + hidden_states = self.mamba( + hidden_states=hidden_states, + cache_params=past_key_value, + cache_position=cache_position, + attention_mask=attention_mask, + ) + self_attn_weights = None + elif self.layer_type == "attention": + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + + # residual connection after attention + hidden_states = residual + hidden_states + + # feed-forward + residual = hidden_states + hidden_states = self.pre_ff_layernorm(hidden_states) + hidden_states = self.feed_forward(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + return outputs + + +BAMBA_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`BambaConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare BambaModel outputting raw hidden-states without any specific head on top.", + BAMBA_START_DOCSTRING, +) +class BambaPreTrainedModel(PreTrainedModel): + config_class = BambaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["BambaDecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True # Note: only supports HybridMambaAttentionDynamicCache + _is_stateful = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, (nn.Linear, nn.Conv1d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +BAMBA_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`HybridMambaAttentionDynamicCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + A HybridMambaAttentionDynamicCache object containing pre-computed hidden-states (keys and values in the + self-attention blocks and convolution and ssm states in the mamba blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + Key and value cache tensors have shape `(batch_size, num_heads, seq_len, head_dim)`. + Convolution and ssm states tensors have shape `(batch_size, d_inner, d_conv)` and + `(batch_size, d_inner, d_state)` respectively. + See the `HybridMambaAttentionDynamicCache` class for more details. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + output_router_logits (`bool`, *optional*): + Whether or not to return the logits of all the routers. They are useful for computing the router loss, and + should not be returned during inference. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + "The bare Bamba Model outputting raw hidden-states without any specific head on top.", + BAMBA_START_DOCSTRING, +) +# Adapted from transformers.models.jamba.modeling_jamba.JambaModel +class BambaModel(BambaPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`BambaDecoderLayer`] + + Args: + config: BambaConfig + """ + + def __init__(self, config: BambaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + decoder_layers = [] + for i in range(config.num_hidden_layers): + decoder_layers.append(BambaDecoderLayer(config, layer_idx=i, layer_type=config.layers_block_type[i])) + self.layers = nn.ModuleList(decoder_layers) + + self._attn_implementation = config._attn_implementation + self.final_layernorm = BambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = BambaRotaryEmbedding(config=config) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(BAMBA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[HybridMambaAttentionDynamicCache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + hidden_states = inputs_embeds + + if use_cache and past_key_values is None: + logger.warning_once( + "Bamba requires an initialized `HybridMambaAttentionDynamicCache` to return a cache. None was " + "provided, so no cache will be returned." + ) + + if cache_position is None: + cache_position = torch.arange(hidden_states.shape[1], device=hidden_states.device) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + mamba_mask = self._update_mamba_mask(attention_mask, cache_position) + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers: + # Depending on the layer type we opt for 2D base attention mask (Mamba) or 4D causal mask (Attention) + layer_mask = mamba_mask if decoder_layer.layer_type == "mamba" else causal_mask + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + layer_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=layer_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + if layer_outputs[1] is not None: + # append attentions only of attention layers. Mamba layers return `None` as the attention weights + all_self_attns += (layer_outputs[1],) + + hidden_states = self.final_layernorm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if past_key_values and not past_key_values.has_previous_state: + past_key_values.has_previous_state = True + + next_cache = None if not use_cache else past_key_values + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: HybridMambaAttentionDynamicCache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + sequence_length = input_tensor.shape[1] + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_attention_mask = (attention_mask[:, None, None, :] == attention_mask[:, None, :, None])[ + :, :, -sequence_length:, : + ].to(dtype) + padding_mask = causal_mask[:, :, :, :mask_length] + padding_attention_mask + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + def _update_mamba_mask(self, attention_mask, cache_position): + """ + No need for zeroing states when + 1. Cached forward + 2. Attending to all inputs + """ + mamba_mask = attention_mask + if cache_position[0] > 0 or (attention_mask is not None and torch.all(attention_mask == 1)): + mamba_mask = None + return mamba_mask + + +class BambaForCausalLM(LlamaForCausalLM): + @add_start_docstrings_to_model_forward(BAMBA_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[HybridMambaAttentionDynamicCache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, + **kwargs, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + num_logits_to_keep (`int` or `None`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `None`, calculate logits for all + `input_ids`. Only last token logits are needed for generation, and calculating them only for that token + can save memory, which becomes pretty significant for long sequences. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, BambaForCausalLM + + >>> model = BambaForCausalLM.from_pretrained("...") + >>> tokenizer = AutoTokenizer.from_pretrained("...") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + return super().forward( + input_ids, + attention_mask, + position_ids, + past_key_values, + inputs_embeds, + labels, + use_cache, + output_attentions, + output_hidden_states, + return_dict, + cache_position, + num_logits_to_keep, + **kwargs, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + **kwargs, + ): + # Overwitten -- has a unique cache type, `HybridMambaAttentionDynamicCache` + + empty_past_kv = past_key_values is None + + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here + if not empty_past_kv: + if inputs_embeds is not None: # Exception 1 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + else: + past_key_values = HybridMambaAttentionDynamicCache( + self.config, input_ids.shape[0], self.dtype, device=self.device + ) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if not empty_past_kv: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and empty_past_kv: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + "num_logits_to_keep": self.config.num_logits_to_keep, + "cache_position": cache_position, + } + ) + return model_inputs + + +__all__ = ["BambaModel", "BambaForCausalLM", "BambaPreTrainedModel"] diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 823c51a290713d..c9a49d737d092e 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -1167,6 +1167,27 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class BambaForCausalLM(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class BambaModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class BambaPreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class BarkCausalModel(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index bf56578a164c94..e85f2663624740 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -2313,6 +2313,7 @@ def _check_outputs(self, output, config, use_cache=False, num_return_sequences=1 # 2. We ignore models that have unique cache structures (e.g. mamba) or are in need of refatoring to match the # standard cache format (e.g.gptbigcode ) models_without_standard_cache = ( + "bamba", "ctrl", "fsmt", "gptbigcode", diff --git a/tests/models/bamba/__init__.py b/tests/models/bamba/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/tests/models/bamba/test_modeling_bamba.py b/tests/models/bamba/test_modeling_bamba.py new file mode 100644 index 00000000000000..45819e66b73c08 --- /dev/null +++ b/tests/models/bamba/test_modeling_bamba.py @@ -0,0 +1,603 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Testing suite for the PyTorch Bamba model.""" + +import inspect +import unittest + +import pytest +from parameterized import parameterized + +from transformers import AutoTokenizer, BambaConfig, is_torch_available +from transformers.testing_utils import ( + require_torch, + slow, + torch_device, +) + +from ...generation.test_utils import GenerationTesterMixin +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, _config_zero_init, ids_tensor +from ...test_pipeline_mixin import PipelineTesterMixin + + +if is_torch_available(): + import torch + + from transformers import ( + BambaForCausalLM, + BambaModel, + ) + from transformers.models.bamba.modeling_bamba import ( + HybridMambaAttentionDynamicCache, + ) + + +class BambaModelTester: + def __init__( + self, + parent, + batch_size=13, + seq_length=7, + is_training=True, + use_input_mask=True, + use_labels=True, + vocab_size=99, + hidden_size=32, + num_hidden_layers=4, + num_attention_heads=4, + num_key_value_heads=2, + intermediate_size=64, + hidden_act="silu", + attention_dropout=0.0, + attn_layer_indices=None, + attn_rotary_emb=8, + max_position_embeddings=512, + type_vocab_size=16, + initializer_range=0.02, + num_labels=3, + pad_token_id=0, + mamba_n_groups=1, + mamba_n_heads=16, + mamba_d_state=16, + mamba_d_conv=4, + mamba_expand=2, + mamba_chunk_size=16, + scope=None, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.use_input_mask = use_input_mask + self.use_labels = use_labels + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.attention_dropout = attention_dropout + self.attn_layer_indices = attn_layer_indices + self.attn_rotary_emb = attn_rotary_emb + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + self.num_labels = num_labels + self.pad_token_id = pad_token_id + self.scope = scope + self.mamba_n_groups = mamba_n_groups + self.mamba_n_heads = mamba_n_heads + self.mamba_d_state = mamba_d_state + self.mamba_d_conv = mamba_d_conv + self.mamba_expand = mamba_expand + self.mamba_chunk_size = mamba_chunk_size + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + + input_mask = None + if self.use_input_mask: + input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device)) + + token_labels = None + if self.use_labels: + token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) + + config = self.get_config() + + return config, input_ids, input_mask, token_labels + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + input_ids, + input_mask, + token_labels, + ) = config_and_inputs + inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask} + return config, inputs_dict + + def get_config(self): + # Fix for SDPA tests, force at least 4 layers + if self.num_hidden_layers < 4: + self.num_hidden_layers = 4 + if self.attn_layer_indices is None: + d = [x for x in range(2, self.num_hidden_layers) if self.num_hidden_layers % x == 0] + if len(d) == 0: + raise ValueError("num_hidden_layers is prime, cannot automatically set attn_layer_indices.") + d = d[-1] # get the largest divisor + self.attn_layer_indices = [x + 1 for x in range(0, self.num_hidden_layers, d)] + + return BambaConfig( + vocab_size=self.vocab_size, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + num_key_value_heads=self.num_key_value_heads, + intermediate_size=self.intermediate_size, + hidden_act=self.hidden_act, + attention_dropout=self.attention_dropout, + attn_layer_indices=self.attn_layer_indices, + attn_rotary_emb=self.attn_rotary_emb, + max_position_embeddings=self.max_position_embeddings, + initializer_range=self.initializer_range, + pad_token_id=self.pad_token_id, + mamba_n_groups=self.mamba_n_groups, + mamba_n_heads=self.mamba_n_heads, + mamba_d_state=self.mamba_d_state, + mamba_d_conv=self.mamba_d_conv, + mamba_expand=self.mamba_expand, + mamba_chunk_size=self.mamba_chunk_size, + ) + + def create_and_check_model( + self, + config, + input_ids, + input_mask, + token_labels, + ): + model = BambaModel(config=config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=input_mask) + result = model(input_ids) + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + + def create_and_check_for_causal_lm( + self, + config, + input_ids, + input_mask, + token_labels, + ): + model = BambaForCausalLM(config=config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=input_mask, labels=token_labels) + result = model(input_ids, attention_mask=input_mask) + result = model(input_ids, labels=token_labels) + result = model(input_ids) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) + + def create_and_check_decoder_model_past_large_inputs( + self, + config, + input_ids, + input_mask, + token_labels, + ): + # config.is_decoder = True + # config.add_cross_attention = True + model = BambaForCausalLM(config=config) + model.to(torch_device) + model.eval() + + # first forward pass + # Attention: Jamba needs the cache to be initialized to return a cache! + past_key_values = HybridMambaAttentionDynamicCache( + config, input_ids.shape[0], model.dtype, device=model.device + ) + outputs = model( + input_ids, + attention_mask=input_mask, + past_key_values=past_key_values, + use_cache=True, + ) + past_key_values = outputs.past_key_values + + # create hypothetical multiple next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size) + next_mask = ids_tensor((self.batch_size, 3), vocab_size=2) + + # append to next input_ids and + next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) + next_attention_mask = torch.cat([input_mask, next_mask], dim=-1) + + output_from_no_past = model( + next_input_ids, + attention_mask=next_attention_mask, + output_hidden_states=True, + )["hidden_states"][0] + output_from_past = model( + next_tokens, + attention_mask=next_attention_mask, + past_key_values=past_key_values, + output_hidden_states=True, + cache_position=torch.arange( + input_ids.shape[1], input_ids.shape[1] + next_tokens.shape[1], device=model.device + ), + )["hidden_states"][0] + + # select random slice + random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() + output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach() + output_from_past_slice = output_from_past[:, :, random_slice_idx].detach() + + self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1]) + + # test that outputs are equal for slice + self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)) + + +@require_torch +class BambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): + all_model_classes = ( + ( + BambaModel, + BambaForCausalLM, + ) + if is_torch_available() + else () + ) + all_generative_model_classes = (BambaForCausalLM,) if is_torch_available() else () + pipeline_model_mapping = ( + { + "feature-extraction": BambaModel, + "text-generation": BambaForCausalLM, + } + if is_torch_available() + else {} + ) + test_headmasking = False + test_pruning = False + fx_compatible = False + + # Need to use `0.8` instead of `0.9` for `test_cpu_offload` + # This is because we are hitting edge cases with the causal_mask buffer + model_split_percents = [0.5, 0.7, 0.8] + + def setUp(self): + self.model_tester = BambaModelTester(self) + self.config_tester = ConfigTester(self, config_class=BambaConfig, hidden_size=64) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_for_casual_lm(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_causal_lm(*config_and_inputs) + + def test_decoder_model_past_with_large_inputs(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs) + + def test_initialization(self): + r""" + Overriding the test_initialization test as the A_log and D params of the Bamba mixer are initialized differently + """ + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + configs_no_init = _config_zero_init(config) + for model_class in self.all_model_classes: + model = model_class(config=configs_no_init) + for name, param in model.named_parameters(): + if param.requires_grad: + if "A_log" in name: + A = torch.arange(1, config.mamba_n_heads + 1, dtype=torch.float32)[None, :] + self.assertTrue(torch.allclose(param.data, torch.log(A), atol=1e-5, rtol=1e-5)) + elif "D" in name: + D = torch.ones(config.mamba_n_heads, dtype=torch.float32) + self.assertTrue(torch.allclose(param.data, D, atol=1e-5, rtol=1e-5)) + else: + self.assertIn( + ((param.data.mean() * 1e9).round() / 1e9).item(), + [0.0, 1.0], + msg=f"Parameter {name} of model {model_class} seems not properly initialized", + ) + + def test_mismatched_shapes_have_properly_initialized_weights(self): + r""" + Overriding the test_mismatched_shapes_have_properly_initialized_weights test because A_log and D params of the + Bamba mixer are initialized differently and we tested that in test_initialization + """ + self.skipTest(reason="Cumbersome and redundant for Bamba") + + def test_attention_outputs(self): + r""" + Overriding the test_attention_outputs test as the Bamba model outputs attention only for its attention layers + """ + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.return_dict = True + + seq_len = getattr(self.model_tester, "seq_length", None) + encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len) + encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length) + + expected_num_attentions = self.model_tester.num_hidden_layers - len(self.model_tester.attn_layer_indices) + + for model_class in self.all_model_classes: + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = False + config.return_dict = True + model = model_class(config) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.attentions + self.assertEqual(len(attentions), expected_num_attentions) + + # check that output_attentions also work using config + del inputs_dict["output_attentions"] + config.output_attentions = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.attentions + self.assertEqual(len(attentions), expected_num_attentions) + + self.assertListEqual( + list(attentions[0].shape[-3:]), + [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length], + ) + out_len = len(outputs) + + # Check attention is always last and order is fine + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + added_hidden_states = 1 + self.assertEqual(out_len + added_hidden_states, len(outputs)) + + self_attentions = outputs.attentions + + self.assertEqual(len(self_attentions), expected_num_attentions) + self.assertListEqual( + list(self_attentions[0].shape[-3:]), + [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length], + ) + + @unittest.skip(reason="Bamba has its own special cache type") + @parameterized.expand([(1, False), (1, True), (4, False)]) + def test_new_cache_format(self, num_beams, do_sample): + pass + + def test_batching_equivalence(self): + # need to disable the tril input mask + orig = self.model_tester.use_input_mask + self.model_tester.use_input_mask = False + super().test_batching_equivalence() + self.model_tester.use_input_mask = orig + + # essentially the same test in test_utils, just adjustment for rtol for this model + @pytest.mark.generate + def test_left_padding_compatibility(self): + # NOTE: left-padding results in small numerical differences. This is expected. + # See https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535 + + # First, filter out models that don't support left padding + # - The model must have generative capabilities + if len(self.all_generative_model_classes) == 0: + self.skipTest(reason="No generative architecture available for this model.") + + # - The model must support padding + if not self.has_attentions: + self.skipTest(reason="This model doesn't support padding.") + + # - The model must be a decoder-only architecture (encoder-based architectures use right-padding) + decoder_only_classes = [] + for model_class in self.all_generative_model_classes: + config, _ = self.prepare_config_and_inputs_for_generate() + if config.is_encoder_decoder: + continue + else: + decoder_only_classes.append(model_class) + if len(decoder_only_classes) == 0: + self.skipTest(reason="No decoder-only architecture available for this model.") + + # - Decoder-only architectures derived from encoder-decoder models could support it in theory, but we haven't + # added support for it yet. We skip these models for now. + has_encoder_attributes = any( + attr_name + for attr_name in config.to_dict().keys() + if attr_name.startswith("encoder") and attr_name != "encoder_no_repeat_ngram_size" + ) + if has_encoder_attributes: + self.skipTest( + reason="The decoder-only derived from encoder-decoder models are not expected to support left-padding." + ) + + # Then, test left-padding + def _prepare_model_kwargs(input_ids, attention_mask, signature): + model_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask} + if "position_ids" in signature: + position_ids = torch.cumsum(attention_mask, dim=-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + model_kwargs["position_ids"] = position_ids + if "cache_position" in signature: + cache_position = torch.arange(input_ids.shape[-1], device=torch_device) + model_kwargs["cache_position"] = cache_position + return model_kwargs + + for model_class in decoder_only_classes: + config, inputs_dict = self.prepare_config_and_inputs_for_generate() + input_ids = inputs_dict["input_ids"] + + # - for left padding we absolutely need to use an all ones + # attention mask, so we do not use the one in inputs_dict + attention_mask = torch.ones_like(input_ids) + + model = model_class(config).to(torch_device).eval() + signature = inspect.signature(model.forward).parameters.keys() + + # no cache as some models require special cache classes to be init outside forward + model.generation_config.use_cache = False + + # Without padding + model_kwargs = _prepare_model_kwargs(input_ids, attention_mask, signature) + next_logits_wo_padding = model(**model_kwargs).logits[:, -1, :] + + # With left-padding (length 32) + # can hardcode pad_token to be 0 as we'll do attn masking anyway + pad_token_id = ( + config.get_text_config().pad_token_id if config.get_text_config().pad_token_id is not None else 0 + ) + pad_size = (input_ids.shape[0], 32) + padding = torch.ones(pad_size, dtype=input_ids.dtype, device=torch_device) * pad_token_id + padded_input_ids = torch.cat((padding, input_ids), dim=1) + padded_attention_mask = torch.cat((torch.zeros_like(padding), attention_mask), dim=1) + model_kwargs = _prepare_model_kwargs(padded_input_ids, padded_attention_mask, signature) + next_logits_with_padding = model(**model_kwargs).logits[:, -1, :] + + # They should result in very similar logits + torch.testing.assert_close(next_logits_wo_padding, next_logits_with_padding, atol=1e-5, rtol=1e-1) + + +@slow +@require_torch +class BambaModelIntegrationTest(unittest.TestCase): + model = None + tokenizer = None + # This variable is used to determine which CUDA device are we using for our runners (A10 or T4) + # Depending on the hardware we get different logits / generations + cuda_compute_capability_major_version = None + + @classmethod + def setUpClass(cls): + model_id = "ibm-fms/Bamba-9B" + cls.model = BambaForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True) + cls.tokenizer = AutoTokenizer.from_pretrained(model_id) + + # feels a bit forced to have to do this for the generation test + cls.tokenizer.pad_token_id = cls.model.config.pad_token_id + cls.tokenizer.padding_side = "left" + + if is_torch_available() and torch.cuda.is_available(): + # 8 is for A100 / A10 and 7 for T4 + cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0] + + def test_simple_generate(self): + # Key 9 for MI300, Key 8 for A100/A10, and Key 7 for T4. + # + # Note: Key 9 is currently set for MI300, but may need potential future adjustments for H100s, + # considering differences in hardware processing and potential deviations in generated text. + EXPECTED_TEXTS = { + # 7: "", + 8: "<|begin_of_text|>Hey how are you doing on this lovely evening? I hope you are all having a good time.", + # 9: """, + } + + self.model.to(torch_device) + + input_ids = self.tokenizer("Hey how are you doing on this lovely evening?", return_tensors="pt")[ + "input_ids" + ].to(torch_device) + out = self.model.generate(input_ids, do_sample=False, max_new_tokens=10) + output_sentence = self.tokenizer.decode(out[0, :]) + self.assertEqual(output_sentence, EXPECTED_TEXTS[self.cuda_compute_capability_major_version]) + + # TODO: there are significant differences in the logits across major cuda versions, which shouldn't exist + if self.cuda_compute_capability_major_version == 8: + with torch.no_grad(): + logits = self.model(input_ids=input_ids, num_logits_to_keep=40).logits + + EXPECTED_LOGITS_NO_GRAD = torch.tensor( + [ + 149., 142., 146., 142., 143., 144., 142., 145., + 142., 146., 144., 146., 147., 147., 148., 145., + 147., 145., 145., 145., 145., 144., 144., 144., + 144., 145., 147., 146., 144., 144., 148., 147., + 148., 147., 147., 147., 146., 146., 148., 148. + ], dtype=torch.bfloat16) # fmt: skip + + torch.testing.assert_close(logits[0, -1, :40].cpu(), EXPECTED_LOGITS_NO_GRAD, rtol=1e-3, atol=1) + + def test_simple_batched_generate_with_padding(self): + # Key 9 for MI300, Key 8 for A100/A10, and Key 7 for T4. + # + # Note: Key 9 is currently set for MI300, but may need potential future adjustments for H100s, + # considering differences in hardware processing and potential deviations in generated text. + EXPECTED_TEXTS = { + 7: [], + 8: [ + "<|begin_of_text|>Hey how are you doing on this lovely evening? I hope you are doing well. I am here", + "!!!<|begin_of_text|>I am late! I need to get to work! I have to get to the", + ], + 9: [], + } + + self.model.to(torch_device) + + inputs = self.tokenizer( + ["Hey how are you doing on this lovely evening?", "I am late! I need to"], + padding=True, + return_tensors="pt", + ).to(torch_device) + out = self.model.generate(**inputs, do_sample=False, max_new_tokens=10) + output_sentences = self.tokenizer.batch_decode(out) + self.assertEqual(output_sentences[0], EXPECTED_TEXTS[self.cuda_compute_capability_major_version][0]) + self.assertEqual(output_sentences[1], EXPECTED_TEXTS[self.cuda_compute_capability_major_version][1]) + + # TODO: there are significant differences in the logits across major cuda versions, which shouldn't exist + if self.cuda_compute_capability_major_version == 8: + with torch.no_grad(): + logits = self.model(input_ids=inputs["input_ids"]).logits + + EXPECTED_LOGITS_NO_GRAD_0 = torch.tensor( + [ + 149., 142., 146., 142., 143., 144., 142., 145., + 142., 146., 144., 146., 147., 147., 148., 145., + 147., 145., 145., 145., 145., 144., 144., 144., + 144., 145., 147., 146., 144., 144., 148., 147., + 148., 147., 147., 147., 146., 146., 148., 148. + ], dtype=torch.bfloat16) # fmt: skip + + EXPECTED_LOGITS_NO_GRAD_1 = torch.tensor( + [ + 182., 178., 177., 174., 176., 176., 178., 178., + 177., 179., 176., 183., 180., 182., 179., 174., + 178., 176., 176., 175., 175., 175., 174., 173., + 174., 182., 180., 176., 177., 177., 180., 176., + 178., 177., 177., 175., 176., 177., 175., 177. + ], dtype=torch.bfloat16) # fmt: skip + + torch.testing.assert_close(logits[0, -1, :40].cpu(), EXPECTED_LOGITS_NO_GRAD_0, rtol=1e-3, atol=1) + torch.testing.assert_close(logits[1, -1, :40].cpu(), EXPECTED_LOGITS_NO_GRAD_1, rtol=1e-3, atol=1) diff --git a/utils/check_config_attributes.py b/utils/check_config_attributes.py index 420d6e6a2475d1..116e26e7834f26 100644 --- a/utils/check_config_attributes.py +++ b/utils/check_config_attributes.py @@ -34,6 +34,9 @@ SPECIAL_CASES_TO_ALLOW = { # 'max_position_embeddings' is not used in modeling file, but needed for eval frameworks like Huggingface's lighteval (https://github.com/huggingface/lighteval/blob/af24080ea4f16eaf1683e353042a2dfc9099f038/src/lighteval/models/base_model.py#L264). # periods and offsers are not used in modeling file, but used in the configuration file to define `layers_block_type` and `layers_num_experts`. + "BambaConfig": [ + "attn_layer_indices", + ], "JambaConfig": [ "max_position_embeddings", "attn_layer_offset",