Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow flash attention 2 and upgrade to transformers 4.34.1 #672

Merged
merged 30 commits into from
Oct 24, 2023
Merged
Show file tree
Hide file tree
Changes from 26 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
6872b7e
llama unit tests pass again
dakinggg Oct 13, 2023
d1efda1
more special casing in tokenizer equivalence check
dakinggg Oct 13, 2023
fdc89f0
move the patch to module init
dakinggg Oct 13, 2023
633ddfc
fix addedtoken -> str
dakinggg Oct 13, 2023
7acace5
precommit
dakinggg Oct 13, 2023
a0e709b
flash2 unit tests pass
dakinggg Oct 14, 2023
4d944b9
precommit
dakinggg Oct 14, 2023
adac5c0
fix tests
dakinggg Oct 14, 2023
6ae2393
precommit
dakinggg Oct 14, 2023
4b6dd78
fix test
dakinggg Oct 14, 2023
4e7c628
try again
dakinggg Oct 14, 2023
c1d00e9
add lazy load option
dakinggg Oct 14, 2023
0ea0d7d
fix test
dakinggg Oct 14, 2023
242f7c9
precommit
dakinggg Oct 14, 2023
d002bd8
add gc collect
dakinggg Oct 14, 2023
e60eef1
Merge branch 'main' into tr34-flash2
dakinggg Oct 17, 2023
dedd4cd
Merge branch 'main' into tr34-flash2
dakinggg Oct 18, 2023
bf6d947
Merge branch 'main' into tr34-flash2
dakinggg Oct 22, 2023
b7a0a96
updates for the patch release
dakinggg Oct 22, 2023
f519adb
precommit
dakinggg Oct 22, 2023
523a329
add documentation for flash attention options
dakinggg Oct 22, 2023
91ec66a
small fixes
dakinggg Oct 22, 2023
1604586
remove default fa setting
dakinggg Oct 22, 2023
834e46f
fix test
dakinggg Oct 22, 2023
317cb1a
fix again
dakinggg Oct 22, 2023
5ec9cea
Merge branch 'main' into tr34-flash2
dakinggg Oct 23, 2023
6db6315
use patch
dakinggg Oct 24, 2023
528340f
remove commented out code
dakinggg Oct 24, 2023
aaa82b2
lazy load comment
dakinggg Oct 24, 2023
a544d4a
fix readme typo
dakinggg Oct 24, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions llmfoundry/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@
import torch

try:
# Before importing any transformers models, we need to disable transformers flash attention if
# we are in an environment with flash attention version <2. Transformers hard errors on a not properly
# gated import otherwise.
import transformers

from llmfoundry import optim, utils
from llmfoundry.data import (ConcatTokensDataset,
MixtureOfDenoisersCollator, NoConcatDataset,
Expand All @@ -14,8 +19,8 @@
ComposerHFT5)
from llmfoundry.models.layers.attention import (
MultiheadAttention, attn_bias_shape, build_alibi_bias, build_attn_bias,
flash_attn_fn, scaled_multihead_dot_product_attention,
triton_flash_attn_fn)
flash_attn_fn, is_flash_v1_installed,
scaled_multihead_dot_product_attention, triton_flash_attn_fn)
from llmfoundry.models.layers.blocks import MPTBlock
from llmfoundry.models.layers.ffn import (FFN_CLASS_REGISTRY, MPTMLP,
build_ffn)
Expand All @@ -24,6 +29,8 @@
MPTForCausalLM, MPTModel,
MPTPreTrainedModel)
from llmfoundry.tokenizers import TiktokenTokenizerWrapper
if is_flash_v1_installed():
transformers.utils.is_flash_attn_available = lambda: False

except ImportError as e:
try:
Expand Down
22 changes: 20 additions & 2 deletions llmfoundry/models/hf/hf_causal_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@

from llmfoundry.models.hf.hf_fsdp import hf_get_init_device
from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithZLoss
from llmfoundry.models.layers.llama_attention_monkeypatch import \
get_llama_attention_patch_fn
from llmfoundry.models.layers.attention import is_flash_v2_installed
from llmfoundry.models.utils import init_empty_weights

try:
Expand Down Expand Up @@ -95,12 +94,28 @@ def __init__(self, om_model_config: Union[DictConfig,
# load the model config
trust_remote_code = om_model_config.get('trust_remote_code', True)
use_auth_token = om_model_config.get('use_auth_token', False)
use_flash_attention_2 = om_model_config.get('use_flash_attention_2',
False)
if use_flash_attention_2 and not is_flash_v2_installed():
raise ValueError(
'use_flash_attention_2 is set to True, but flash-attention 2 is not installed. '
+ 'Please install flash_attn==2.3.2`.')

config = AutoConfig.from_pretrained(
om_model_config.pretrained_model_name_or_path,
trust_remote_code=trust_remote_code,
use_auth_token=use_auth_token,
)

# This is not how you are supposed to set this, but transformers currently only
# supports enabling flash attention 2 when using the from_pretrained API.
# We need to support it for both from_pretrained and from_config, so we have to
# set the private attribute here. This will just skip all of transformers'
# validation logic that it is ok to use flash attention 2, so we check
# whether it is installed above, and whether the chosen config supports it here.
# https://github.com/huggingface/transformers/issues/26878
config._flash_attn_2_enabled = use_flash_attention_2

# set config overrides
for k, v in om_model_config.get('config_overrides', {}).items():
if not hasattr(config, k):
Expand Down Expand Up @@ -200,6 +215,9 @@ def __init__(self, om_model_config: Union[DictConfig,
)
from transformers.models.llama.modeling_llama import \
LlamaAttention

from llmfoundry.models.layers.llama_attention_monkeypatch import \
get_llama_attention_patch_fn
LlamaAttention.forward = get_llama_attention_patch_fn(
attention_patch_type)
model.config.use_cache = False
Expand Down
4 changes: 2 additions & 2 deletions llmfoundry/tokenizers/tiktoken.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def convert_ids_to_tokens(
"""
if isinstance(ids, int):
if ids in self.added_tokens_decoder:
return self.added_tokens_decoder[ids]
return str(self.added_tokens_decoder[ids])

return self._convert_id_to_token(ids)

Expand All @@ -171,7 +171,7 @@ def convert_ids_to_tokens(
if index in self.added_tokens_decoder:
tokens.append(self.encoding.decode(current_stream))
current_stream = []
tokens.append(self.added_tokens_decoder[index])
tokens.append(str(self.added_tokens_decoder[index]))
else:
current_stream.append(index)

Expand Down
61 changes: 54 additions & 7 deletions scripts/train/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@ This README walks through pretraining and finetuning a large language model usin
#### Table of Contents
1. [Part 1: LLM Pretraining](#llmpretraining)
1. [Installation](#installation)
2. [Dataset Preparation](#datasetpreparation)
3. [How to start single and multi-node pretraining](#howtostartpretraining)
2. [Part 2: LLM Finetuning](#llmfinetuning)
1. [Dataset Preparation](#datasetpreparation)
1. [How to start single and multi-node pretraining](#howtostartpretraining)
1. [Part 2: LLM Finetuning](#llmfinetuning)
1. [Using a dataset on the HuggingFace Hub](#hfdataset)
2. [Using a local dataset](#localdataset)
3. [Using a StreamingDataset (MDS) formatted dataset locally or in an object store](#mdsdataset)
3. [FAQ: How many GPUs do I need to train a LLM?](#howmandygpus)
4. [FAQ: Optimizing Performance](#optimizingperformance)
1. [Using a local dataset](#localdataset)
1. [Using a StreamingDataset (MDS) formatted dataset locally or in an object store](#mdsdataset)
1. [Using Flash Attention](#flashattention)
1. [FAQ: How many GPUs do I need to train a LLM?](#howmandygpus)
1. [FAQ: Optimizing Performance](#optimizingperformance)

# Part 1: LLM Pretraining <a name="llmpretraining"></a>

Expand Down Expand Up @@ -332,6 +333,52 @@ train_loader:
...
```

# Using Flash Attention <a name="flashattention"></a>

Flash Attention is an optimized implementation of the attention mechanism, first introduced by [Dao et al.](https://github.com/Dao-AILab/flash-attention). There are three versions of Flash Attention that can be used with LLM Foundry: Flash Attention V1, Flash Attention V2, and a Triton implementation of Flash Attention. To start, we recommend using one of our [provided Docker images](../../README.md#mosaicml-docker-images) corresponding to the Flash Attention version you would like to use. The Triton implementation can be used with either Flash Attention V1 or V2. Next, how you specify to use Flash Attention depends on which model you are using.

For MPT, you can specify Flash Attention in your YAML like so:
```yaml
model:
name: mpt_causal_lm
...
attn_config:
# Will use either V1 or V2 depending on what is installed
# "triton" will use the Triton implementation
attn_impl: flash
...
```

If loading MPT from the HuggingFace Hub, you can specify Flash Attention in your YAML like so:
```yaml
model:
name: hf_causal_lm
pretrained_model_name_or_path: mosaicml/mpt-7b
...
config_overrides:
# Will use either V1 or V2 depending on what is installed
# "triton" will use the Triton implementation
attn_config:
attn_impl: flash
...
```

For any HuggingFace model that supports Flash Attention (e.g. Llama and Mistral), you can specify Flash Attention in your YAML like so:
```yaml
model:
name: hf_causal_lm
use_flash_attention_2: True # Will be automatically set to True if Flash Attention V2 is installed and the model supports it
irenedea marked this conversation as resolved.
Show resolved Hide resolved
...
```
HuggingFace models currently only support Flash Attention V2.

For Llama specifically, we have another option if you would like to use the Triton implementation of Flash Attention. You can specify this in your YAML like so:
```yaml
model:
name: llama
dakinggg marked this conversation as resolved.
Show resolved Hide resolved
attention_patch_type: triton
dakinggg marked this conversation as resolved.
Show resolved Hide resolved
...
```

# FAQ: How many GPUs do I need to train a LLM? <a name="howmanygpus"></a>
This is a complicated question in general, but if we assume that you are using FSDP with `FULL_SHARD`,
Expand Down
7 changes: 7 additions & 0 deletions scripts/train/train.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0
import copy
import gc
import logging
import os
import sys
Expand Down Expand Up @@ -216,6 +217,11 @@ def main(cfg: DictConfig) -> Trainer:
os.environ[
'PYTORCH_CUDA_ALLOC_CONF'] = f'max_split_size_mb:{max_split_size_mb}'

# Set CUDA lazy loading
dakinggg marked this conversation as resolved.
Show resolved Hide resolved
cuda_load_lazy: bool = cfg.pop('cuda_load_lazy', True)
if cuda_load_lazy:
os.environ['CUDA_MODULE_LOADING'] = 'LAZY'

# Set seed first
seed: int = pop_config(cfg, 'seed', must_exist=True)
reproducibility.seed_all(seed)
Expand Down Expand Up @@ -634,6 +640,7 @@ def main(cfg: DictConfig) -> Trainer:
print('Logging config')
log_config(logged_cfg)
torch.cuda.empty_cache()
gc.collect()

# Eval first if requested
if eval_first and trainer.state.timestamp.batch.value == 0:
Expand Down
9 changes: 5 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
install_requires = [
'mosaicml[libcloud,wandb,mlflow,oci,gcs]>=0.16.4,<0.17',
'accelerate>=0.20,<0.21', # for HF inference `device_map`
'transformers>=4.33,<4.34',
'transformers>=4.34.1,<4.35',
'mosaicml-streaming>=0.6,<0.7',
'torch>=1.13.1,<2.1.1',
'datasets>=2.14.5,<2.15',
Expand Down Expand Up @@ -114,9 +114,10 @@
extra_deps['all-cpu'] = set(
dep for key, deps in extra_deps.items() for dep in deps if 'gpu' not in key)
extra_deps['all'] = set(dep for key, deps in extra_deps.items() for dep in deps
if key != 'gpu-flash2')
extra_deps['all-flash2'] = set(
dep for key, deps in extra_deps.items() for dep in deps if key != 'gpu')
if key not in {'gpu-flash2', 'all-cpu'})
extra_deps['all-flash2'] = set(dep for key, deps in extra_deps.items()
for dep in deps
if key not in {'gpu', 'all', 'all-cpu'})

setup(
name=_PACKAGE_NAME,
Expand Down
43 changes: 43 additions & 0 deletions tests/test_hf_conversion_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,49 @@ def check_hf_tokenizer_equivalence(tokenizer1: PreTrainedTokenizerBase,
tokenizer1.__dict__['init_kwargs'].pop('auto_map', None)
tokenizer2.__dict__['init_kwargs'].pop('auto_map', None)

# Additional special tokens do not match between original tokenizer and loaded tokenizer due to transformers
# constructor differences
additional_special_tokens_1 = {
t if isinstance(t, str) else t.content
for t in tokenizer1.__dict__.pop('_additional_special_tokens', [])
}
additional_special_tokens_2 = {
t if isinstance(t, str) else t.content
for t in tokenizer2.__dict__.pop('_additional_special_tokens', [])
}
# Also pop it out of init_kwargs
tokenizer1.__dict__['init_kwargs'].pop('additional_special_tokens', None)
tokenizer2.__dict__['init_kwargs'].pop('additional_special_tokens', None)
tokenizer1.__dict__['init_kwargs'].pop('added_tokens_decoder', None)
tokenizer2.__dict__['init_kwargs'].pop('added_tokens_decoder', None)
# If the additional special tokens are the same (or a subset of each other), or if one of them is empty, then we are good
assert additional_special_tokens_1.issubset(
additional_special_tokens_2) or additional_special_tokens_2.issubset(
additional_special_tokens_1)

# The special token attributes may be strings or they may be AddedToken objects, so we just check string values
# First check that they have the same attrs
assert tokenizer1.SPECIAL_TOKENS_ATTRIBUTES == tokenizer2.SPECIAL_TOKENS_ATTRIBUTES
# Then check that the values are the same
for special_token_attr in tokenizer1.SPECIAL_TOKENS_ATTRIBUTES:
# Skip additional_special_tokens because we already checked it above
if special_token_attr == 'additional_special_tokens':
continue

# The init_kwargs can change between the original tokenizer and the loaded tokenizer,
# so we just pop them
tokenizer1.__dict__['init_kwargs'].pop(special_token_attr, None)
tokenizer2.__dict__['init_kwargs'].pop(special_token_attr, None)

attr1 = tokenizer1.__dict__.pop('_' + special_token_attr, None)
attr2 = tokenizer2.__dict__.pop('_' + special_token_attr, None)
if attr1 is None and attr2 is None:
continue

attr_value1 = attr1 if isinstance(attr1, str) else attr1.content
attr_value2 = attr2 if isinstance(attr2, str) else attr2.content
assert attr_value1 == attr_value2

assert tokenizer1.__dict__ == tokenizer2.__dict__


Expand Down
Loading
Loading