Add flash attention for gpt_bigcode (#26479)
* added flash attention for gpt_bigcode

* changed docs

* Update src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py

* add FA-2 docs

* oops

* Update docs/source/en/perf_infer_gpu_one.md Last Nit

Co-authored-by: Arthur <[email protected]>

* fix

* oops

* remove padding_mask

* change getattr->hasattr logic

* changed .md file

---------

Co-authored-by: Younes Belkada <[email protected]>
Co-authored-by: younesbelkada <[email protected]>
Co-authored-by: Arthur <[email protected]>
4 people authored Oct 31, 2023
1 parent 9dc4ce9 commit b5db8ca
Showing 3 changed files with 328 additions and 23 deletions.
39 changes: 39 additions & 0 deletions docs/source/en/model_doc/gpt_bigcode.md
@@ -42,6 +42,45 @@ The main differences compared to GPT2.

You can read more about the optimizations in the [original pull request](https://github.com/huggingface/transformers/pull/22575)

## Combining Starcoder and Flash Attention 2

First, make sure to install the latest version of Flash Attention 2:

```bash
pip install -U flash-attn --no-build-isolation
```

Also make sure that your hardware is compatible with Flash Attention 2 (you can read more about this in the official documentation of the [flash-attn](https://github.com/Dao-AILab/flash-attention) repository), and make sure to load your model in half-precision (e.g. `torch.float16`).

To load and run a model using Flash Attention 2, refer to the snippet below:

```python
>>> import torch
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> device = "cuda" # the device to load the model onto

>>> model = AutoModelForCausalLM.from_pretrained("bigcode/gpt_bigcode-santacoder", torch_dtype=torch.float16, use_flash_attention_2=True)
>>> tokenizer = AutoTokenizer.from_pretrained("bigcode/gpt_bigcode-santacoder")

>>> prompt = "def hello_world():"

>>> model_inputs = tokenizer([prompt], return_tensors="pt").to(device)
>>> model.to(device)

>>> generated_ids = model.generate(**model_inputs, max_new_tokens=30, do_sample=False)
>>> tokenizer.batch_decode(generated_ids)[0]
'def hello_world():\n print("hello world")\n\nif __name__ == "__main__":\n print("hello world")\n<|endoftext|>'
```

### Expected speedups

Below is an expected speedup diagram comparing pure inference time between the native implementation in transformers using the `bigcode/starcoder` checkpoint and the Flash Attention 2 version of the model, for two different sequence lengths.

<div style="text-align: center">
<img src="https://huggingface.co/datasets/ybelkada/documentation-images/resolve/main/starcoder-speedup.png">
</div>
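
As a hedged illustration (not part of this commit's diff), the comparison could be reproduced locally with a script along these lines; the prompt, generation lengths, and the smaller `bigcode/gpt_bigcode-santacoder` checkpoint are assumptions made for a quick local test rather than the setup actually used for the diagram.

```python
# Minimal benchmarking sketch: time greedy generation with and without Flash
# Attention 2 on the same prompt. The prompt, token counts and the use of the
# smaller `bigcode/gpt_bigcode-santacoder` checkpoint are illustrative assumptions.
import time

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "bigcode/gpt_bigcode-santacoder"
device = "cuda"

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
inputs = tokenizer(["def fibonacci(n):"], return_tensors="pt").to(device)


def time_generation(use_flash_attention_2: bool) -> float:
    model = AutoModelForCausalLM.from_pretrained(
        checkpoint,
        torch_dtype=torch.float16,
        use_flash_attention_2=use_flash_attention_2,
    ).to(device)
    # Warm-up run so CUDA kernel launches and caches don't skew the measurement.
    model.generate(**inputs, max_new_tokens=8, do_sample=False)
    torch.cuda.synchronize()
    start = time.perf_counter()
    model.generate(**inputs, max_new_tokens=256, do_sample=False)
    torch.cuda.synchronize()
    return time.perf_counter() - start


eager = time_generation(use_flash_attention_2=False)
flash = time_generation(use_flash_attention_2=True)
print(f"eager: {eager:.2f}s | flash-attn-2: {flash:.2f}s | speedup: {eager / flash:.2f}x")
```

Note that the gain generally grows with sequence length, so a short prompt like this one will show a smaller speedup than the longer sequences in the diagram above.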


## GPTBigCodeConfig

[[autodoc]] GPTBigCodeConfig
1 change: 1 addition & 0 deletions docs/source/en/perf_infer_gpu_one.md
@@ -34,6 +34,7 @@ We natively support Flash Attention 2 for the following models:
- Llama
- Mistral
- Falcon
- [GPTBigCode (Starcoder)](model_doc/gpt_bigcode)

You can request Flash Attention 2 support for more models by opening an issue on GitHub, or even open a Pull Request to integrate the changes yourself. The supported models can be used for inference and training, including training with padding tokens - *which is currently not supported for the `BetterTransformer` API below.*
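
As a hedged sketch (not part of this commit's diff) of what training with padding tokens looks like, a single training step could be written roughly as follows; the checkpoint, prompts, and optimizer settings are illustrative assumptions.

```python
# Minimal sketch of one training step with a padded batch and Flash Attention 2.
# Checkpoint, prompts and hyperparameters are illustrative assumptions.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "bigcode/gpt_bigcode-santacoder"
device = "cuda"

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # reuse EOS as padding if no pad token is set

# bfloat16 keeps the sketch simple (no gradient scaler needed for training).
model = AutoModelForCausalLM.from_pretrained(
    checkpoint, torch_dtype=torch.bfloat16, use_flash_attention_2=True
).to(device)

# Sequences of different lengths, so the batch really contains padding tokens.
batch = tokenizer(
    ["def add(a, b):", "class Stack:\n    def __init__(self):"],
    return_tensors="pt",
    padding=True,
).to(device)

# For causal LM training, labels are the input ids with padded positions ignored.
labels = batch["input_ids"].clone()
labels[batch["attention_mask"] == 0] = -100

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
outputs = model(**batch, labels=labels)
outputs.loss.backward()
optimizer.step()
print(f"loss: {outputs.loss.item():.3f}")
```

The padding positions are excluded from the loss via the `-100` label convention, while the attention mask tells the Flash Attention 2 code path which tokens are padding.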

311 changes: 288 additions & 23 deletions src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py (diff not loaded)
