Skip to content

Commit

Permalink
[core ] Integrate Flash attention 2 in most used models (#25598)
Browse files Browse the repository at this point in the history
* v1

* oops

* working v1

* fixup

* add some TODOs

* fixup

* padding support + try with module replacement

* nit

* alternative design

* oops

* add `use_cache` support for llama

* v1 falcon

* nit

* a bit of refactor

* nit

* nits nits

* add v1 padding support falcon (even though it seemed to work before)

* nit

* falcon works

* fixup

* v1 tests

* nit

* fix generation llama flash

* update tests

* fix tests + nits

* fix copies

* fix nit

* test- padding mask

* stype

* add more mem efficient support

* Update src/transformers/modeling_utils.py

Co-authored-by: Patrick von Platen <[email protected]>

* fixup

* nit

* fixup

* remove it from config when saving

* fixup

* revert docstring

* add more checks

* use values

* oops

* new version

* fixup

* add same trick for falcon

* nit

* add another test

* change tests

* fix issues with GC and also falcon

* fixup

* oops

* Update src/transformers/models/falcon/modeling_falcon.py

Co-authored-by: Arthur <[email protected]>

* add init_rope

* updates

* fix copies

* fixup

* fixup

* more clarification

* fixup

* right padding tests

* add docs

* add FA in docker image

* more clarifications

* add some figures

* add todo

* rectify comment

* Change to FA2

* Update docs/source/en/perf_infer_gpu_one.md

Co-authored-by: Arthur <[email protected]>

* split in two lines

* change test name

* add more tests

* some clean up

* remove `rearrange` deps

* add more docs

* revert changes on dockerfile

* Revert "revert changes on dockerfile"

This reverts commit 8d72a66.

* revert changes on dockerfile

* Apply suggestions from code review

Co-authored-by: Lysandre Debut <[email protected]>

* address some comments

* docs

* use inheritance

* Update src/transformers/testing_utils.py

Co-authored-by: Lysandre Debut <[email protected]>

* fixup

* Apply suggestions from code review

Co-authored-by: Arthur <[email protected]>

* Update src/transformers/modeling_utils.py

* final comments

* clean up

* style

* add cast + warning for PEFT models

* fixup

---------

Co-authored-by: Felix Marty <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>
Co-authored-by: Arthur <[email protected]>
Co-authored-by: Lysandre Debut <[email protected]>
  • Loading branch information
5 people authored Sep 22, 2023
1 parent dcbfd93 commit 368a58e
Show file tree
Hide file tree
Showing 14 changed files with 934 additions and 14 deletions.
4 changes: 4 additions & 0 deletions docs/source/en/perf_infer_gpu_many.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ Note: A multi GPU setup can use the majority of the strategies described in the

</Tip>

## Flash Attention 2

Flash Attention 2 integration also works in a multi-GPU setup, check out the appropriate section in the [single GPU section](./perf_infer_gpu_one#Flash-Attention-2)

## BetterTransformer

[BetterTransformer](https://huggingface.co/docs/optimum/bettertransformer/overview) converts 🤗 Transformers models to use the PyTorch-native fastpath execution, which calls optimized kernels like Flash Attention under the hood.
Expand Down
148 changes: 148 additions & 0 deletions docs/source/en/perf_infer_gpu_one.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,154 @@ rendered properly in your Markdown viewer.

In addition to this guide, relevant information can be found as well in [the guide for training on a single GPU](perf_train_gpu_one) and [the guide for inference on CPUs](perf_infer_cpu).

## Flash Attention 2

<Tip>

Note that this feature is experimental and might considerably change in future versions. For instance, the Flash Attention 2 API might migrate to `BetterTransformer` API in the near future.

</Tip>

Flash Attention 2 can considerably speed up transformer-based models' training and inference speed. Flash Attention 2 has been introduced in the [official Flash Attention repository](https://github.com/Dao-AILab/flash-attention) by Tri Dao et al. The scientific paper on Flash Attention can be found [here](https://arxiv.org/abs/2205.14135).

Make sure to follow the installation guide on the repository mentioned above to properly install Flash Attention 2. Once that package is installed, you can benefit from this feature.

We natively support Flash Attention 2 for the following models:

- Llama
- Falcon

You can request to add Flash Attention 2 support for more models by opening an issue on GitHub, and even open a Pull Request to integrate the changes. The supported models can be used for inference and training, including training with padding tokens - *which is currently not supported for `BetterTransformer` API below.*

<Tip>

Flash Attention 2 can only be used when the models' dtype is `fp16` or `bf16` and runs only on NVIDIA-GPU devices. Make sure to cast your model to the appropriate dtype and load them on a supported device before using that feature.

</Tip>

### Quick usage

To enable Flash Attention 2 in your model, add `use_flash_attention_2` in the `from_pretrained` arguments:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM

model_id = "tiiuae/falcon-7b"
tokenizer = AutoTokenizer.from_pretrained(model_id)

model = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype=torch.bfloat16,
use_flash_attention_2=True,
)
```

And use it for generation or fine-tuning.

### Expected speedups

You can benefit from considerable speedups for fine-tuning and inference, especially for long sequences. However, since Flash Attention does not support computing attention scores with padding tokens under the hood, we must manually pad / unpad the attention scores for batched inference when the sequence contains padding tokens. This leads to a significant slowdown for batched generations with padding tokens.

To overcome this, one should use Flash Attention without padding tokens in the sequence for training (e.g., by packing a dataset, i.e., concatenating sequences until reaching the maximum sequence length. An example is provided [here](https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_clm.py#L516).

Below is the expected speedup you can get for a simple forward pass on [tiiuae/falcon-7b](https://hf.co/tiiuae/falcon-7b) with a sequence length of 4096 and various batch sizes without padding tokens:

Below is the expected speedup you can get for a simple forward pass on [tiiuae/falcon-7b](https://hf.co/tiiuae/falcon-7b) with a sequence length of 4096 and various batch sizes, without padding tokens:

<div style="text-align: center">
<img src="https://huggingface.co/datasets/ybelkada/documentation-images/resolve/main/falcon-7b-inference-large-seqlen.png">
</div>

Below is the expected speedup you can get for a simple forward pass on [`meta-llama/Llama-7b-hf`](https://hf.co/meta-llama/Llama-7b-hf) with a sequence length of 4096 and various batch sizes, without padding tokens:

<div style="text-align: center">
<img src="https://huggingface.co/datasets/ybelkada/documentation-images/resolve/main/llama-7b-inference-large-seqlen.png">
</div>

For sequences with padding tokens (training with padding tokens or generating with padding tokens), we need to unpad / pad the input sequences to compute correctly the attention scores. For relatively small sequence length, on pure forward pass, this creates an overhead leading to a small speedup (below 30% of the input has been filled with padding tokens).

<div style="text-align: center">
<img src="https://huggingface.co/datasets/ybelkada/documentation-images/resolve/main/llama-2-small-seqlen-padding.png">
</div>

But for large sequence length you can benefit from interesting speedup for pure inference (also training)

Note that Flash Attention makes the attention computation more memory efficient, meaning you can train with much larger sequence lengths without facing CUDA OOM issues. It can lead up to memory reduction up to 20 for large sequence length. Check out [the official flash attention repository](https://github.com/Dao-AILab/flash-attention) for more details.

<div style="text-align: center">
<img src="https://huggingface.co/datasets/ybelkada/documentation-images/resolve/main/llama-2-large-seqlen-padding.png">
</div>


### Advanced usage

You can combine this feature with many exisiting feature for model optimization. Check out few examples below:

### Combining Flash Attention 2 and 8-bit models

You can combine this feature together with 8-bit quantization:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM

model_id = "tiiuae/falcon-7b"
tokenizer = AutoTokenizer.from_pretrained(model_id)

model = AutoModelForCausalLM.from_pretrained(
model_id,
load_in_8bit=True,
use_flash_attention_2=True,
)
```

### Combining Flash Attention 2 and 4-bit models

You can combine this feature together with 4-bit quantization:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM

model_id = "tiiuae/falcon-7b"
tokenizer = AutoTokenizer.from_pretrained(model_id)

model = AutoModelForCausalLM.from_pretrained(
model_id,
load_in_4bit=True,
use_flash_attention_2=True,
)
```

### Combining Flash Attention 2 and PEFT

You can combine this feature together with PEFT for training adapters using Flash Attention 2 under the hood:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM
from peft import LoraConfig

model_id = "tiiuae/falcon-7b"
tokenizer = AutoTokenizer.from_pretrained(model_id)

model = AutoModelForCausalLM.from_pretrained(
model_id,
load_in_4bit=True,
use_flash_attention_2=True,
)

lora_config = LoraConfig(
r=8,
task_type="CAUSAL_LM"
)

model.add_adapter(lora_config)

... # train your model
```

## BetterTransformer

[BetterTransformer](https://huggingface.co/docs/optimum/bettertransformer/overview) converts 🤗 Transformers models to use the PyTorch-native fastpath execution, which calls optimized kernels like Flash Attention under the hood.
Expand Down
4 changes: 4 additions & 0 deletions docs/source/en/perf_train_gpu_one.md
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,10 @@ For additional information on tf32 vs other precisions, please refer to the foll
[RTX-3090](https://github.com/huggingface/transformers/issues/14608#issuecomment-1004390803) and
[A100](https://github.com/huggingface/transformers/issues/15026#issuecomment-1004543189).

## Flash Attention 2

You can speedup the training throughput by using Flash Attention 2 integration in transformers. Check out the appropriate section in the [single GPU section](./perf_infer_gpu_one#Flash-Attention-2) to learn more about how to load a model with Flash Attention 2 modules.

## Optimizer choice

The most common optimizer used to train transformer models is Adam or AdamW (Adam with weight decay). Adam achieves
Expand Down
5 changes: 5 additions & 0 deletions src/transformers/configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -855,6 +855,9 @@ def to_diff_dict(self) -> Dict[str, Any]:

self.dict_torch_dtype_to_str(serializable_config_dict)

if "_flash_attn_2_enabled" in serializable_config_dict:
del serializable_config_dict["_flash_attn_2_enabled"]

return serializable_config_dict

def to_dict(self) -> Dict[str, Any]:
Expand All @@ -871,6 +874,8 @@ def to_dict(self) -> Dict[str, Any]:
del output["_auto_class"]
if "_commit_hash" in output:
del output["_commit_hash"]
if "_flash_attn_2_enabled" in output:
del output["_flash_attn_2_enabled"]

# Transformers version when serializing the model
output["transformers_version"] = __version__
Expand Down
86 changes: 86 additions & 0 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
is_accelerate_available,
is_auto_gptq_available,
is_bitsandbytes_available,
is_flash_attn_available,
is_offline_mode,
is_optimum_available,
is_peft_available,
Expand Down Expand Up @@ -1116,6 +1117,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
is_parallelizable = False
supports_gradient_checkpointing = False

# Flash Attention 2 support
_supports_flash_attn_2 = False

@property
def dummy_inputs(self) -> Dict[str, torch.Tensor]:
"""
Expand Down Expand Up @@ -1239,6 +1243,84 @@ def can_generate(cls) -> bool:
return False
return True

@classmethod
def _check_and_enable_flash_attn_2(
cls, config, torch_dtype: Optional[torch.dtype] = None, device_map: Optional[Union[str, Dict[str, int]]] = None
) -> PretrainedConfig:
"""
If you don't know about Flash Attention, check out the official repository of flash attention:
https://github.com/Dao-AILab/flash-attention
For using Flash Attention 1.0 you can do it directly via the `BetterTransformer` API, have a look at this
specific section of the documentation to learn more about it:
https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#decoder-models
The method checks if the current setup is compatible with Flash Attention as it requires the model to be in
half precision and not ran on CPU.
If all checks pass, the method will create an attribute in the config `_flash_attn_2_enabled` so that the model
can initialize the correct attention module
"""
if not cls._supports_flash_attn_2:
raise ValueError(
"The current architecture does not support Flash Attention 2.0. Please open an issue on GitHub to "
"request support for this architecture: https://github.com/huggingface/transformers/issues/new"
)

if not is_flash_attn_available():
raise ImportError(
"Flash Attention 2.0 is not available. Please refer to the documentation of https://github.com/Dao-AILab/flash-attention for"
" installing it."
)
else:
flash_attention_version = version.parse(importlib.metadata.version("flash_attn"))
is_flash_greater_than_2 = flash_attention_version > version.parse("2.0.0")
if not is_flash_greater_than_2:
raise ValueError(
f"You need flash_attn package version to be greater than 2.0. Make sure to have that version installed - detected version {flash_attention_version}"
)

_is_bettertransformer = getattr(cls, "use_bettertransformer", False)

if _is_bettertransformer:
raise ValueError(
"Flash Attention 2 and BetterTransformer API are not compatible. Please make sure to disable BetterTransformers by doing model.reverse_bettertransformer()"
)

if torch_dtype is None:
logger.warning(
"You are attempting to use Flash Attention 2.0 without specifying a torch dtype. This might lead to unexpected behaviour"
)
elif torch_dtype is not None and torch_dtype not in [torch.float16, torch.bfloat16]:
raise ValueError(
f"Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes. You passed {torch_dtype}, this might lead to"
" unexpected behaviour."
)

if device_map is None:
if torch.cuda.is_available():
logger.warning(
"You are attempting to use Flash Attention 2.0 with a model initialized on CPU. Make sure to move the model to GPU"
" after initializing it on CPU with `model.to('cuda')`."
)
else:
raise ValueError(
"You are attempting to use Flash Attention 2.0 with a model initialized on CPU and with no GPU available. "
"This is not supported yet. Please make sure to have access to a GPU and either initialise the model on a GPU by passing a device_map "
"or initialising the model on CPU and then moving it to GPU."
)
elif (
device_map is not None
and isinstance(device_map, dict)
and ("cpu" in device_map.values() or "disk" in device_map.values())
):
raise ValueError(
"You are attempting to use Flash Attention 2.0 with a model dispatched on CPU or disk. This is not supported. Please make sure to "
"initialise the model on a GPU by passing a device_map that contains only GPU devices as keys."
)
config._flash_attn_2_enabled = True
return config

def enable_input_require_grads(self):
"""
Enables the gradients for the input embeddings. This is useful for fine-tuning adapter weights while keeping
Expand Down Expand Up @@ -2374,6 +2456,7 @@ def from_pretrained(
variant = kwargs.pop("variant", None)
_adapter_model_path = kwargs.pop("_adapter_model_path", None)
adapter_name = kwargs.pop("adapter_name", "default")
use_flash_attention_2 = kwargs.pop("use_flash_attention_2", False)

if is_fsdp_enabled():
low_cpu_mem_usage = True
Expand Down Expand Up @@ -2977,6 +3060,9 @@ def from_pretrained(
elif load_in_8bit or load_in_4bit or low_cpu_mem_usage:
init_contexts.append(init_empty_weights())

if use_flash_attention_2:
config = cls._check_and_enable_flash_attn_2(config, torch_dtype=torch_dtype, device_map=device_map)

with ContextManagers(init_contexts):
model = cls(config, *model_args, **model_kwargs)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,6 @@ def __init__(self, config: OpenLlamaConfig):
self.input_layernorm = OpenLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = OpenLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

# Copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward
def forward(
self,
hidden_states: torch.Tensor,
Expand Down
Loading

0 comments on commit 368a58e

Please sign in to comment.