Skip to content

Commit

Permalink
Add the Bamba Model (#34982)
Browse files Browse the repository at this point in the history
* initial commit for PR

Co-authored-by: Gabe Goodhart <[email protected]>

* rename dynamic cache

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* add more unit tests

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* add integration test

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* add integration test

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* Add modular bamba file

* Remove trainer changes from unrelated PR

* Modify modular and cofig to get model running

* Fix some CI errors and beam search

* Fix a plethora of bugs from CI/docs/etc

* Add bamba to models with special caches

* Updat to newer mamba PR for mamba sublayer

* fix test_left_padding_compatibility

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* fix style

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* fix remaining tests

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* missed this test

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* ran make style

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* move slow tag to integration obj

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* make style

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* address comments

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* fix modular

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* left out one part of modular

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* change model

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* Make Rotary modular as well

* Update bamba.md

Added overview, update Model inference card and added config

* Update bamba.md

* Update bamba.md

* Update bamba.md

Minor fixes

* Add docs for config and model back

Signed-off-by: Antoni Viros i Martin <[email protected]>

* Add warning when using fast kernels

* replaced generate example

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* Address comments from PR

Signed-off-by: Antoni Viros i Martin <[email protected]>

* Propagate attention fixes

Signed-off-by: Antoni Viros i Martin <[email protected]>

* Fix attention interfaces to the new API

Signed-off-by: Antoni Viros i Martin <[email protected]>

* Fix API for decoder layer

Signed-off-by: Antoni Viros i Martin <[email protected]>

* Remove extra weights

Signed-off-by: Antoni Viros i Martin <[email protected]>

---------

Signed-off-by: Yu Chin Fabian Lim <[email protected]>
Signed-off-by: Antoni Viros i Martin <[email protected]>
Co-authored-by: Gabe Goodhart <[email protected]>
Co-authored-by: Antoni Viros i Martin <[email protected]>
Co-authored-by: divya-kumari32 <[email protected]>
Co-authored-by: Antoni Viros <[email protected]>
  • Loading branch information
5 people authored Dec 18, 2024
1 parent 9a94dfe commit 9613933
Show file tree
Hide file tree
Showing 19 changed files with 4,138 additions and 0 deletions.
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,8 @@
sections:
- local: model_doc/albert
title: ALBERT
- local: model_doc/bamba
title: Bamba
- local: model_doc/bart
title: BART
- local: model_doc/barthez
Expand Down
1 change: 1 addition & 0 deletions docs/source/en/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ Flax), PyTorch, and/or TensorFlow.
| [AriaText](model_doc/aria_text) ||||
| [Audio Spectrogram Transformer](model_doc/audio-spectrogram-transformer) ||||
| [Autoformer](model_doc/autoformer) ||||
| [Bamba](model_doc/bamba) ||||
| [Bark](model_doc/bark) ||||
| [BART](model_doc/bart) ||||
| [BARThez](model_doc/barthez) ||||
Expand Down
64 changes: 64 additions & 0 deletions docs/source/en/model_doc/bamba.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->

# Bamba


## Overview

Bamba-9B is a decoder-only language model based on the [Mamba-2](https://github.com/state-spaces/mamba) architecture and is designed to handle a wide range of text generation tasks. It is trained from scratch using a two-stage training approach. In the first stage, the model is trained on 2 trillion tokens from the Dolma v1.7 dataset. In the second stage, it undergoes additional training on 200 billion tokens, leveraging a carefully curated blend of high-quality data to further refine its performance and enhance output quality.

Checkout all Bamba-9B model checkpoints [here](https://github.com/foundation-model-stack/bamba).

## BambaConfig

| Model | Params | # Layers | Hidden Dim. | Attention Heads | GQA | KV Heads | Context Length | Tied Embeddings |
|-------------------|--------------|----------|-------------|-----------------|-----|----------|----------------|------------------|
| Bamba | 9B (9.78B) | 32 | 4096 | 32 | Yes | 8 | 4096 | True |

[[autodoc]] BambaConfig

<!---
## Usage Tips
Tips:
- The architecture is based on Mamba-2 models.
## BambaModel
[[autodoc]] BambaModel
- forward
-->

## BambaForCausalLM

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("ibm-fms/Bamba-9B")
tokenizer = AutoTokenizer.from_pretrained("ibm-fms/Bamba-9B")

message = ["Mamba is a snake with following properties "]
inputs = tokenizer(message, return_tensors='pt', return_token_type_ids=False)
response = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.batch_decode(response, skip_special_tokens=True)[0])
```

[[autodoc]] BambaForCausalLM
- forward

This HF implementation is contributed by [ani300](https://github.com/ani300) and [fabianlim](https://github.com/fabianlim).
2 changes: 2 additions & 0 deletions docs/source/en/perf_infer_gpu_one.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ FlashAttention-2 is experimental and may change considerably in future versions.
FlashAttention-2 is currently supported for the following architectures:
* [Aria](https://huggingface.co/docs/transformers/model_doc/aria#transformers.AriaForConditionalGeneration)
* [Bark](https://huggingface.co/docs/transformers/model_doc/bark#transformers.BarkModel)
* [Bamba](https://huggingface.co/docs/transformers/model_doc/bamba#transformers.BambaModel)
* [Bart](https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartModel)
* [Chameleon](https://huggingface.co/docs/transformers/model_doc/chameleon#transformers.Chameleon)
* [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPModel)
Expand Down Expand Up @@ -220,6 +221,7 @@ For now, Transformers supports SDPA inference and training for the following arc
* [Albert](https://huggingface.co/docs/transformers/model_doc/albert#transformers.AlbertModel)
* [Aria](https://huggingface.co/docs/transformers/model_doc/aria#transformers.AriaForConditionalGeneration)
* [Audio Spectrogram Transformer](https://huggingface.co/docs/transformers/model_doc/audio-spectrogram-transformer#transformers.ASTModel)
* [Bamba](https://huggingface.co/docs/transformers/model_doc/bamba#transformers.BambaModel)
* [Bart](https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartModel)
* [Beit](https://huggingface.co/docs/transformers/model_doc/beit#transformers.BeitModel)
* [Bert](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertModel)
Expand Down
10 changes: 10 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@
"AutoTokenizer",
],
"models.autoformer": ["AutoformerConfig"],
"models.bamba": ["BambaConfig"],
"models.bark": [
"BarkCoarseConfig",
"BarkConfig",
Expand Down Expand Up @@ -1540,6 +1541,13 @@
"AutoformerPreTrainedModel",
]
)
_import_structure["models.bamba"].extend(
[
"BambaForCausalLM",
"BambaModel",
"BambaPreTrainedModel",
]
)
_import_structure["models.bark"].extend(
[
"BarkCausalModel",
Expand Down Expand Up @@ -5104,6 +5112,7 @@
from .models.autoformer import (
AutoformerConfig,
)
from .models.bamba import BambaConfig
from .models.bark import (
BarkCoarseConfig,
BarkConfig,
Expand Down Expand Up @@ -6493,6 +6502,7 @@
AutoformerModel,
AutoformerPreTrainedModel,
)
from .models.bamba import BambaForCausalLM, BambaModel, BambaPreTrainedModel
from .models.bark import (
BarkCausalModel,
BarkCoarseModel,
Expand Down
1 change: 1 addition & 0 deletions src/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1693,6 +1693,7 @@ def _supports_default_dynamic_cache(self) -> bool:
self._supports_cache_class
and "jamba" not in self.__class__.__name__.lower()
and "zamba" not in self.__class__.__name__.lower()
and "bamba" not in self.__class__.__name__.lower()
)

def _prepare_cache_for_generation(
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
audio_spectrogram_transformer,
auto,
autoformer,
bamba,
bark,
bart,
barthez,
Expand Down
2 changes: 2 additions & 0 deletions src/transformers/models/auto/configuration_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
("aria_text", "AriaTextConfig"),
("audio-spectrogram-transformer", "ASTConfig"),
("autoformer", "AutoformerConfig"),
("bamba", "BambaConfig"),
("bark", "BarkConfig"),
("bart", "BartConfig"),
("beit", "BeitConfig"),
Expand Down Expand Up @@ -337,6 +338,7 @@
("aria_text", "AriaText"),
("audio-spectrogram-transformer", "Audio Spectrogram Transformer"),
("autoformer", "Autoformer"),
("bamba", "Bamba"),
("bark", "Bark"),
("bart", "BART"),
("barthez", "BARThez"),
Expand Down
2 changes: 2 additions & 0 deletions src/transformers/models/auto/modeling_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
("aria_text", "AriaTextModel"),
("audio-spectrogram-transformer", "ASTModel"),
("autoformer", "AutoformerModel"),
("bamba", "BambaModel"),
("bark", "BarkModel"),
("bart", "BartModel"),
("beit", "BeitModel"),
Expand Down Expand Up @@ -471,6 +472,7 @@
[
# Model for Causal LM mapping
("aria_text", "AriaTextForCausalLM"),
("bamba", "BambaForCausalLM"),
("bart", "BartForCausalLM"),
("bert", "BertLMHeadModel"),
("bert-generation", "BertGenerationDecoder"),
Expand Down
28 changes: 28 additions & 0 deletions src/transformers/models/bamba/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Copyright 2024 IBM and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
from .configuration_bamba import *
from .modeling_bamba import *
from .processing_bamba import *
else:
import sys

_file = globals()["__file__"]
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
Loading

0 comments on commit 9613933

Please sign in to comment.