[core] Integrate Flash attention 2 in most used models #25598
Changes from 74 commits
@@ -17,6 +17,142 @@ rendered properly in your Markdown viewer.

In addition to this guide, relevant information can also be found in [the guide for training on a single GPU](perf_train_gpu_one) and [the guide for inference on CPUs](perf_infer_cpu).

## Flash Attention 2

<Tip>

Note that this feature is experimental and might change considerably in future versions. For instance, the Flash Attention 2 API might be migrated to the `BetterTransformer` API in the near future.
Comment on lines +24 to +25:
- did not know this was planned 😄 If not let's just not say anything
- Well it is all about providing a single meaningful API to users and avoid confusing them. In PyTorch 2.2 (hopefully not too late!), we'll be in a state where FA2 will be supported by SDPA so basically a duplicate of this.
- As felix said, the goal in the future would be to have an unified API through
</Tip>

Flash Attention 2 can considerably speed up the training and inference of transformer-based models. Flash Attention 2 was introduced in the [official Flash Attention repository](https://github.com/Dao-AILab/flash-attention) by Tri Dao et al. The scientific paper on Flash Attention can be found [here](https://arxiv.org/abs/2205.14135).

Make sure to follow the installation guide in the repository mentioned above to properly install Flash Attention 2. Once that package is installed, you can benefit from this feature.

We natively support Flash Attention 2 for some models; the currently supported architectures are:

- Llama
- Falcon

You can request to add Flash Attention 2 support for more models by opening an issue on GitHub!

These models can be used for inference and training, including training with padding tokens - which is currently not supported by the `BetterTransformer` API below.
<Tip>

Flash Attention 2 can only be used with models in fp16 or bf16 dtype, and only on NVIDIA GPU devices. Make sure to cast your model to the appropriate dtype and load it on a supported device before using this feature.

</Tip>

### Quick usage

To enable Flash Attention 2 in your model, simply pass `use_flash_attn_2=True` to `from_pretrained`:
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "tiiuae/falcon-7b"
tokenizer = AutoTokenizer.from_pretrained(model_id)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    use_flash_attn_2=True,
)
```

And use it for generation or fine-tuning.
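For instance, reusing the `model` and `tokenizer` from the snippet above, a minimal generation sketch could look as follows (the prompt and generation settings are illustrative, and the model is moved to the GPU explicitly since no `device_map` was passed):

```python
# Flash Attention 2 only runs on CUDA devices, so move the model to the GPU first
model = model.to("cuda")

inputs = tokenizer("My favourite condiment is", return_tensors="pt").to("cuda")
outputs = model.generate(**inputs, max_new_tokens=30)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```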
### Expected speedups

You can benefit from considerable speedups for fine-tuning and inference, especially for long sequence lengths.
However, note that since Flash Attention does not support computing attention scores with padding tokens under the hood, we need to manually pad / unpad the attention scores for batched inference when the sequence contains padding tokens. This leads to a significant slowdown for batched `generate` with padding tokens. To overcome this, one should use Flash Attention without padding tokens in the sequence for training (e.g. by packing a dataset, i.e. concatenating sequences until reaching the maximum sequence length).

Review comments:
- IMO you should use
- (applies to the entire document)
- I'd add a link to a doc explaining that in our docs and/or to some of our examples that do it (for ex I think that's what's happening here in run_clm)
- Yep added few lines in that direction, let me know how that sounds to you
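To make the packing idea concrete, here is a minimal sketch assuming the dataset has already been tokenized into lists of input ids (the function name and the handling of leftover tokens are illustrative assumptions, not something defined in this PR):

```python
def pack_sequences(tokenized_examples, max_seq_len=4096):
    """Concatenate tokenized sequences into fixed-length chunks so no padding is needed."""
    packed, buffer = [], []
    for input_ids in tokenized_examples:
        buffer.extend(input_ids)
        # Emit full chunks of exactly max_seq_len tokens
        while len(buffer) >= max_seq_len:
            packed.append(buffer[:max_seq_len])
            buffer = buffer[max_seq_len:]
    # Leftover tokens shorter than max_seq_len are dropped here;
    # they could also be kept and padded separately
    return packed
```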
Below is the expected speedup you can get for a simple forward pass on `tiiuae/falcon-7b` with a sequence length of 4096 and various batch sizes, without padding tokens:

<div style="text-align: center">
<img src="https://huggingface.co/datasets/ybelkada/documentation-images/resolve/main/falcon-7b-inference-large-seqlen.png">
</div>

Below is the expected speedup you can get for a simple forward pass on `meta-llama/Llama-7b-hf` with a sequence length of 4096 and various batch sizes, without padding tokens:

<div style="text-align: center">
<img src="https://huggingface.co/datasets/ybelkada/documentation-images/resolve/main/llama-7b-inference-large-seqlen.png">
</div>
TODO: @younesbelkada add more figures and cases where FA fails.

Review comment:
- I'll address that a bit later, I need to check first if we can merge younesbelkada#5

Note that Flash Attention makes the attention computation more memory efficient, meaning you can train with much larger sequence lengths without facing CUDA OOM issues.
### Advanced usage

Review comment:
- very nice examples!

You can combine this feature with many existing features for model optimization. Check out a few examples below:

### Combining Flash Attention 2 and 8-bit models

You can combine this feature together with 8-bit quantization:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "tiiuae/falcon-7b"
tokenizer = AutoTokenizer.from_pretrained(model_id)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    load_in_8bit=True,
    use_flash_attn_2=True,
)
```
### Combining Flash Attention 2 and 4-bit models

You can combine this feature together with 4-bit quantization:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "tiiuae/falcon-7b"
tokenizer = AutoTokenizer.from_pretrained(model_id)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    load_in_4bit=True,
    use_flash_attn_2=True,
)
```
### Combining Flash Attention 2 and PEFT

You can combine this feature together with PEFT for training adapters using Flash Attention 2 under the hood:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig

model_id = "tiiuae/falcon-7b"
tokenizer = AutoTokenizer.from_pretrained(model_id)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    load_in_4bit=True,
    use_flash_attn_2=True,
)

lora_config = LoraConfig(
    r=8,
    task_type="CAUSAL_LM"
)

model.add_adapter(lora_config)

... # train your model
```
## BetterTransformer

[BetterTransformer](https://huggingface.co/docs/optimum/bettertransformer/overview) converts 🤗 Transformers models to use the PyTorch-native fastpath execution, which calls optimized kernels like Flash Attention under the hood.
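As a rough sketch of how BetterTransformer is typically enabled and disabled (this assumes the `optimum` package is installed; the checkpoint is only an example, and whether a given architecture is supported depends on `optimum`):

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("tiiuae/falcon-7b")

# Swap the model's attention layers for the PyTorch-native fastpath kernels
model = model.to_bettertransformer()

# ... run inference or training ...

# Convert back to the canonical transformers implementation, e.g. before saving
model = model.reverse_bettertransformer()
```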
@@ -70,6 +70,7 @@

```python
    is_accelerate_available,
    is_auto_gptq_available,
    is_bitsandbytes_available,
    is_flash_attn_available,
    is_offline_mode,
    is_optimum_available,
    is_peft_available,
```
@@ -1116,6 +1117,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix

```python
    is_parallelizable = False
    supports_gradient_checkpointing = False

    # Flash Attention 2 support
    _supports_flash_attn_2 = False

    @property
    def dummy_inputs(self) -> Dict[str, torch.Tensor]:
        """
```
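The `_supports_flash_attn_2` attribute defaults to `False` on the base class; an architecture opts in by overriding it. A hypothetical sketch of such an opt-in (the class name is illustrative, not one of the classes touched by this PR):

```python
from transformers import PreTrainedModel


class MyModelPreTrainedModel(PreTrainedModel):
    # Declares that this architecture implements a Flash Attention 2 code path,
    # so _check_and_enable_flash_attn_2 will not reject it
    _supports_flash_attn_2 = True
```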
@@ -1239,6 +1243,83 @@ def can_generate(cls) -> bool:

```python
            return False
        return True

    @classmethod
    def _check_and_enable_flash_attn_2(
        cls, config, torch_dtype: Optional[torch.dtype] = None, device_map: Optional[Union[str, Dict[str, int]]] = None
    ) -> PretrainedConfig:
```

Comment on lines +1247 to +1249:
- good error raising here
```python
        """
        If you don't know about Flash Attention, check out the official repository of flash attention:
        https://github.com/Dao-AILab/flash-attention

        For using Flash Attention 1.0 you can do it directly via the `BetterTransformer` API, have a look at this
        specific section of the documentation to learn more about it:
        https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#decoder-models

        The method checks if the current setup is compatible with Flash Attention as it requires the model to be in
        half precision and not run on CPU.

        If all checks pass, the method will create an attribute in the config `_flash_attn_2_enabled` so that the model
        can initialize the correct attention module
        """
        if not cls._supports_flash_attn_2:
            raise ValueError(
                "The current architecture does not support Flash Attention 2.0. Please open an issue on GitHub to "
                "request support for this architecture."
            )
```

Review comment:
- I'd add a link to https://github.com/huggingface/transformers/issues/new
```python
        if not is_flash_attn_available():
            raise ImportError(
                "Flash Attention 2.0 is not available. Please refer to the documentation of https://github.com/Dao-AILab/flash-attention for"
                " installing it."
            )
        else:
            is_flash_greater_than_2 = version.parse(importlib.metadata.version("flash_attn")) > version.parse("2.0.0")
            if not is_flash_greater_than_2:
                raise ValueError(
                    "You need flash_attn package version to be greater than 2.0. Make sure to have that version installed."
                )
```

Review comments:
- Maybe print the current version they have installed currently
- Makes sense!
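A possible way to act on that suggestion and surface the installed version in the error message (a sketch only; the exact wording adopted in the PR may differ):

```python
import importlib.metadata

from packaging import version

flash_attention_version = version.parse(importlib.metadata.version("flash_attn"))
if flash_attention_version <= version.parse("2.0.0"):
    raise ValueError(
        f"You need flash_attn package version to be greater than 2.0. Detected version {flash_attention_version}. "
        "Make sure to have that version installed."
    )
```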
```python
        _is_bettertransformer = getattr(cls, "use_bettertransformer", False)

        if _is_bettertransformer:
            raise ValueError(
                "Flash Attention 2 and BetterTransformer API are not compatible. Please make sure to disable BetterTransformers by doing model.reverse_bettertransformer()"
            )
```

Comment on lines +1285 to +1288:
- Should this just toggle it off with an
```python
        if torch_dtype is None:
            warnings.warn(
                "You are attempting to use Flash Attention 2.0 without specifying a torch dtype. This might lead to unexpected behaviour"
            )
```

Review comments:
- I don't know if the convention changed, but originally we favored
- Thanks for the heads up ! Changed it with logger.warning
- flagging as it seems to still be a warning.warn
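For reference, the `logger.warning` variant suggested in the thread would look roughly like this, using the transformers logging utility (the message text is taken from the diff above):

```python
from transformers.utils import logging

logger = logging.get_logger(__name__)

logger.warning(
    "You are attempting to use Flash Attention 2.0 without specifying a torch dtype. This might lead to unexpected behaviour"
)
```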
```python
        elif torch_dtype is not None and torch_dtype not in [torch.float16, torch.bfloat16]:
            raise ValueError(
                f"Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes. You passed {torch_dtype}, this might lead to"
                " unexpected behaviour."
            )

        if device_map is None:
            if torch.cuda.is_available():
                warnings.warn(
                    "You are attempting to use Flash Attention 2.0 with a model initialized on CPU. Make sure to move the model to GPU"
                    " after initializing it on CPU with `model.to('cuda')`."
                )
            else:
                raise ValueError(
                    "You are attempting to use Flash Attention 2.0 with a model initialized on CPU and with no GPU available. "
                    "This is not supported. Please make sure to have access to a GPU and either initialise the model on a GPU by passing a device_map "
                    "or initialising the model on CPU and then moving it to GPU."
                )
        elif (
            device_map is not None
            and isinstance(device_map, dict)
            and ("cpu" in device_map.values() or "disk" in device_map.values())
        ):
            raise ValueError(
                "You are attempting to use Flash Attention 2.0 with a model dispatched on CPU or disk. This is not supported. Please make sure to "
                "initialise the model on a GPU by passing a device_map that contains only GPU devices as keys."
            )
```
Comment on lines +1318 to +1320:
- so basically we can't use it if you don't have enough gpu VRAM. It's not 100% clear for me
- yes, also not supported if you explicitly want to do CPU / Disk offloading
- Actually why is this not supported? It shouldn't be a problem to support Flash Attention + cpu offload IMO (we're supporting it for diffusers)
- Would be nice to support indeed, would enable a bunch of larger models
- Let's update the comment
- This would require some work as you need to instantiate a
- I propose to do it in a follow up PR
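To make the rejected configurations concrete, here is a small illustration of device maps against the check above (the module names are made up for the example):

```python
# Accepted: every weight placed on a GPU
device_map_ok = {"": 0}

# Rejected by the check above: some weights offloaded to CPU or disk
device_map_offload = {"transformer.word_embeddings": 0, "lm_head": "cpu"}
```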
```python
        config._flash_attn_2_enabled = True
        return config

    def enable_input_require_grads(self):
        """
        Enables the gradients for the input embeddings. This is useful for fine-tuning adapter weights while keeping
```
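Since this method only sets `config._flash_attn_2_enabled`, the actual switch happens in each supported architecture when its layers are built. A hypothetical sketch of that pattern (the class names are illustrative, not the exact ones introduced by this PR):

```python
import torch.nn as nn


class DummyAttention(nn.Module):
    def __init__(self, config):
        super().__init__()


class DummyFlashAttention2(nn.Module):
    def __init__(self, config):
        super().__init__()


class DummyDecoderLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        # Pick the attention implementation based on the flag set by
        # _check_and_enable_flash_attn_2 on the config
        attn_cls = (
            DummyFlashAttention2
            if getattr(config, "_flash_attn_2_enabled", False)
            else DummyAttention
        )
        self.self_attn = attn_cls(config)
```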
@@ -2369,6 +2450,7 @@ def from_pretrained(

```python
        variant = kwargs.pop("variant", None)
        _adapter_model_path = kwargs.pop("_adapter_model_path", None)
        adapter_name = kwargs.pop("adapter_name", "default")
        use_flash_attn_2 = kwargs.pop("use_flash_attn_2", False)

        if is_fsdp_enabled():
            low_cpu_mem_usage = True
```

@@ -2980,6 +3062,9 @@ def from_pretrained(

```python
        elif load_in_8bit or load_in_4bit or low_cpu_mem_usage:
            init_contexts.append(init_empty_weights())

        if use_flash_attn_2:
            config = cls._check_and_enable_flash_attn_2(config, torch_dtype=torch_dtype, device_map=device_map)

        with ContextManagers(init_contexts):
            model = cls(config, *model_args, **model_kwargs)
```
Review comments:
- I need to test this and make sure building this new docker image works as expected
- See my comment in #26268 (comment)