Commit

resolved some comments from the PR
ShashankMosaicML committed Nov 2, 2023
1 parent cceca07 commit 1e59de5
Showing 8 changed files with 138 additions and 156 deletions.
4 changes: 2 additions & 2 deletions TUTORIAL.md
@@ -353,8 +353,8 @@ Currently we support [Learned Positional Embeddings](https://arxiv.org/pdf/1706.
|:-----------------------------------|:------------------------------------------------------------------|:---------------------------------------------------:|:----------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| Learned Positional Embeddings | <pre>model:<br> learned_pos_emb:&nbsp;True</pre>| 65.7 | |
| ALiBi | <pre>model:<br> attn_config:<br> alibi:&nbsp;True</pre>| 64.5 | Requires Triton or Torch attention. |
| RoPE (Dao-AILab Implementation) | <pre>model:<br> attn_config:<br> rope:&nbsp;True<br> rope_imp:&nbsp;dail</pre>| 64.5 | Requires a CUDA GPU and [the flash-attn library](https://github.com/Dao-AILab/flash-attention) (v2.0.1 or higher) to be installed. Please see the instructions in the [paragraph above](#support-for-flashattention-2) on how to install the flash-attn library v2. Note that the attention implementation can still be torch, triton, or flash; it just needs the flash-attn library (v2.0.1 or higher) installed, since we import its RotaryEmbedding class. |
| RoPE (Hugging<code>&nbsp;</code>Face Implementation) | <pre>model:<br> attn_config:<br> rope:&nbsp;True<br> rope_imp:&nbsp;hf</pre>| 62.3 | |
| RoPE (Dao-AILab Implementation) | <pre>model:<br> attn_config:<br> rope:&nbsp;True<br> rope_impl:&nbsp;dail</pre>| 64.5 | Requires a CUDA GPU and [the flash-attn library](https://github.com/Dao-AILab/flash-attention) (v2.0.1 or higher) to be installed. Please see the instructions in the [paragraph above](#support-for-flashattention-2) on how to install the flash-attn library v2. Note that the attention implementation can still be torch, triton, or flash; it just needs the flash-attn library (v2.0.1 or higher) installed, since we import its RotaryEmbedding class. |
| RoPE (Hugging<code>&nbsp;</code>Face Implementation) | <pre>model:<br> attn_config:<br> rope:&nbsp;True<br> rope_impl:&nbsp;hf</pre>| 62.3 | |
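
For reference, a minimal YAML sketch of enabling RoPE with the renamed `rope_impl` key from the table above (the nested keys mirror `attn_config_defaults` later in this diff; the values shown are illustrative, not recommendations):

```yaml
model:
  attn_config:
    rope: True
    rope_theta: 10000
    rope_impl: dail          # or 'hf' for the Hugging Face implementation
    rope_dail_config:
      type: original         # 'xpos' additionally uses xpos_scale_base
      pos_idx_in_fp32: True
      xpos_scale_base: 512
```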

### Can I finetune using PEFT / LoRA?
- The LLM Foundry codebase does not directly have examples of PEFT or LORA workflows. However, our MPT model is a subclass of HuggingFace `PretrainedModel`, and https://github.com/mosaicml/llm-foundry/pull/346 added required features to enable HuggingFace’s [PEFT](https://huggingface.co/docs/peft/index) / [LORA](https://huggingface.co/docs/peft/conceptual_guides/lora) workflows for MPT. MPT models with LoRA modules can be trained either using LLM Foundry or Hugging Face's [accelerate](https://huggingface.co/docs/accelerate/index). Within LLM Foundry, run (`scripts/train/train.py`), adding `lora` arguments to the config `.yaml`, like so:
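
  (What follows is a purely hypothetical sketch: the field names mirror PEFT's `LoraConfig` arguments and are assumptions, not the exact schema used by `scripts/train/train.py`.)

```yaml
# Hypothetical sketch: key names follow PEFT's LoraConfig and may not match
# the exact config schema expected by LLM Foundry.
lora:
  args:
    r: 16
    lora_alpha: 32
    lora_dropout: 0.05
    target_modules: ['Wqkv']  # assumption: MPT's fused QKV projection
```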
20 changes: 12 additions & 8 deletions llmfoundry/models/layers/attention.py
Expand Up @@ -33,6 +33,9 @@ def is_flash_v1_installed():
return version.parse(flash_attn.__version__) < version.parse('2.0.0')


# Before importing any transformers models, we need to disable transformers flash attention if
# we are in an environment with flash attention version <2. Otherwise, transformers raises a
# hard error on an improperly gated import.
if is_flash_v1_installed():
import transformers
transformers.utils.is_flash_attn_available = lambda: False
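
A standalone sketch of this gating pattern, assuming only what the diff shows: `flash_attn` exposes `__version__`, and `transformers.utils.is_flash_attn_available` is the helper being patched:

```python
from packaging import version

try:
    import flash_attn
    # flash-attn v1.x is installed if its version parses below 2.0.0.
    flash_v1_installed = version.parse(
        flash_attn.__version__) < version.parse('2.0.0')
except ImportError:
    flash_v1_installed = False

if flash_v1_installed:
    import transformers

    # Report flash attention as unavailable so transformers does not attempt
    # an import path that only works with flash-attn v2.
    transformers.utils.is_flash_attn_available = lambda: False
```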
@@ -593,12 +596,14 @@ def forward(
rotary_emb = rotary_emb_w_meta_info['rotary_emb']
seq_len = rotary_emb_w_meta_info['seq_len']
offset_info = rotary_emb_w_meta_info['offset_info']

query = query.view(*(query.shape[:-1]), -1, self.head_dim)
key = key.view(*(key.shape[:-1]), -1, self.head_dim)
assert query.shape[:2] == key.shape[:2]
assert query.shape[:2] == key.shape[:2]
bsz, seqlen = query.shape[:2]
query = query.view(bsz, seqlen, -1, self.head_dim)
key = key.view(bsz, seqlen, -1, self.head_dim)

if rotary_emb_w_meta_info['imp'] == 'dail':
value = value.view(*(value.shape[:-1]), -1, self.head_dim)
value = value.view(bsz, seqlen, -1, self.head_dim)

kv = torch.stack([key, value], dim=2)
query, kv = rotary_emb(query,
@@ -607,8 +612,7 @@
max_seqlen=seq_len)
[key, value] = torch.unbind(kv, dim=2)

value = value.view(*(value.shape[:-2]),
self.kv_n_heads * self.head_dim)
value = value.view(bsz, seqlen, self.kv_n_heads * self.head_dim)
elif rotary_emb_w_meta_info['imp'] == 'hf':
(cos, sin) = rotary_emb(value, seq_len)
# The following two transposes should be removed once the transformers library allows for the specification of the dimension for heads in the call to apply_rotary_pos_emb
@@ -620,8 +624,8 @@
query = query.transpose(1, 2)
key = key.transpose(1, 2)

query = query.view(*(query.shape[:-2]), self.d_model)
key = key.view(*(key.shape[:-2]), self.kv_n_heads * self.head_dim)
query = query.view(bsz, seqlen, self.d_model)
key = key.view(bsz, seqlen, self.kv_n_heads * self.head_dim)

context, attn_weights, past_key_value = self.attn_fn(
query,
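A self-contained sketch (plain PyTorch, not code from this commit) of the reshape pattern the new lines use: flat projections of shape `(bsz, seqlen, n_heads * head_dim)` are viewed per head so rotary embeddings can act on `head_dim`, then flattened back before the attention call:

```python
import torch

bsz, seqlen, n_heads, head_dim = 2, 8, 4, 16
query = torch.randn(bsz, seqlen, n_heads * head_dim)

# Split the flat projection into per-head vectors: (2, 8, 4, 16).
query_heads = query.view(bsz, seqlen, -1, head_dim)
assert query_heads.shape == (bsz, seqlen, n_heads, head_dim)

# ... rotary embeddings would be applied along head_dim here ...

# Flatten back to (2, 8, 64) before calling the attention function.
query_flat = query_heads.view(bsz, seqlen, n_heads * head_dim)
assert torch.equal(query_flat, query)
```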
52 changes: 27 additions & 25 deletions llmfoundry/models/layers/blocks.py
@@ -12,6 +12,31 @@
from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY, build_ffn
from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY

attn_config_defaults: Dict = {
'attn_type': 'multihead_attention',
'attn_pdrop': 0.0,
'attn_impl': 'triton',
'qk_ln': False,
'clip_qkv': None,
'softmax_scale': None,
'prefix_lm': False,
'attn_uses_sequence_id': False,
'alibi': False,
'alibi_bias_max': 8,
'rope': False,
'rope_theta': 10000,
'rope_impl': 'dail',
'rope_dail_config': {
'type': 'original',
'pos_idx_in_fp32': True,
'xpos_scale_base': 512,
},
'rope_hf_config': {
'type': 'no_scaling',
'factor': 1.0,
},
}


class MPTBlock(nn.Module):

@@ -30,30 +55,7 @@ def __init__(
**kwargs: Any,
):
if attn_config is None:
attn_config = {
'attn_type': 'multihead_attention',
'attn_pdrop': 0.0,
'attn_impl': 'triton',
'qk_ln': False,
'clip_qkv': None,
'softmax_scale': None,
'prefix_lm': False,
'attn_uses_sequence_id': False,
'alibi': False,
'alibi_bias_max': 8,
'rope': False,
'rope_theta': 10000,
'rope_imp': 'dail',
'rope_dail_config': {
'type': 'original',
'pos_idx_in_fp32': True,
'xpos_scale_base': 512,
},
'rope_hf_config': {
'type': 'no_scaling',
'factor': 1.0,
},
}
attn_config = attn_config_defaults

if ffn_config is None:
ffn_config = {
@@ -70,7 +72,7 @@ def __init__(
# necessary to avoid passing extraneous args into attn_class while allowing the use of **kwargs
args_to_exclude_in_attn_class = {
'attn_type', 'prefix_lm', 'alibi', 'attn_uses_sequence_id',
'alibi_bias_max', 'rope', 'rope_theta', 'rope_imp',
'alibi_bias_max', 'rope', 'rope_theta', 'rope_impl',
'rope_dail_config', 'rope_hf_config'
}
attn_config_subset_for_attn_class = {
36 changes: 7 additions & 29 deletions llmfoundry/models/mpt/configuration_mpt.py
@@ -8,30 +8,7 @@

from transformers import PretrainedConfig

attn_config_defaults: Dict = {
'attn_type': 'multihead_attention',
'attn_pdrop': 0.0,
'attn_impl': 'triton',
'qk_ln': False,
'clip_qkv': None,
'softmax_scale': None,
'prefix_lm': False,
'attn_uses_sequence_id': False,
'alibi': False,
'alibi_bias_max': 8,
'rope': False,
'rope_theta': 10000,
'rope_imp': 'dail',
'rope_dail_config': {
'type': 'original',
'pos_idx_in_fp32': True,
'xpos_scale_base': 512,
},
'rope_hf_config': {
'type': 'no_scaling',
'factor': 1.0,
},
}
from llmfoundry.models.layers.blocks import attn_config_defaults

ffn_config_defaults: Dict = {
'ffn_type': 'mptmlp',
@@ -108,7 +85,7 @@ def __init__(
alibi_bias_max (int): The maximum value of the alibi bias.
rope (bool): Whether to use rotary positional embeddings.
rope_theta (int): The base frequency for rope.
rope_imp (str): The implementation of rope to use. One of 'hf' (to use the implementation from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py) or 'dail' (to use the implementation from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/layers/rotary.py).
rope_impl (str): The implementation of rope to use. One of 'hf' (to use the implementation from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py) or 'dail' (to use the implementation from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/layers/rotary.py).
rope_dail_config (Dict): The configuration for the dail implementation of rope.
type (str): The type of rotary position embedding to use. Options: 'original' (for https://arxiv.org/pdf/2104.09864.pdf), 'xpos' (for https://arxiv.org/pdf/2212.10554.pdf).
pos_idx_in_fp32 (bool): If True, the position indices [0, ..., seqlen - 1] are in fp32, otherwise they might be in lower precision. A consequence could be, for example, that bf16 rounds position 1995 to 2000, which leads to them having the same positional embedding.
@@ -189,6 +166,7 @@ def _set_config_defaults(self, config: Dict[str, Any],
if k not in config:
config[k] = v
elif isinstance(v, dict):
# recursively set default values for any sub-dicts
config[k] = self._set_config_defaults(
config[k] if (config[k] is not None) else {}, v)
return config
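
A minimal standalone analogue of `_set_config_defaults` (not the class method itself), showing how nested defaults such as `rope_hf_config` are filled in without overwriting user-supplied values:

```python
from typing import Any, Dict


def set_config_defaults(config: Dict[str, Any],
                        defaults: Dict[str, Any]) -> Dict[str, Any]:
    for k, v in defaults.items():
        if k not in config:
            # Missing keys fall back to the default value.
            config[k] = v
        elif isinstance(v, dict):
            # Recursively fill defaults inside sub-dicts.
            config[k] = set_config_defaults(
                config[k] if config[k] is not None else {}, v)
    return config


user_cfg = {'rope': True, 'rope_hf_config': {'type': 'linear'}}
defaults = {
    'rope': False,
    'rope_theta': 10000,
    'rope_hf_config': {'type': 'no_scaling', 'factor': 1.0},
}
print(set_config_defaults(user_cfg, defaults))
# {'rope': True, 'rope_hf_config': {'type': 'linear', 'factor': 1.0}, 'rope_theta': 10000}
```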
@@ -233,21 +211,21 @@ def _validate_config(self) -> None:
raise NotImplementedError(
'attn_uses_sequence_id only implemented with torch and triton attention.'
)
if self.attn_config['rope'] and (self.attn_config['rope_imp']
if self.attn_config['rope'] and (self.attn_config['rope_impl']
not in ['dail', 'hf']):
raise ValueError(
'If rope is being used then rope_imp should be either "dail", or "hf".'
'If rope is being used then rope_impl should be either "dail", or "hf".'
)
if self.attn_config['rope'] and (
self.attn_config['rope_imp']
self.attn_config['rope_impl']
== 'hf') and self.attn_config['rope_hf_config']['type'] not in [
'no_scaling', 'linear', 'dynamic'
]:
raise ValueError(
'If using hf implementation of rope, the type should be one of "no_scaling", "linear" or "dynamic".'
)
if self.attn_config['rope'] and (
self.attn_config['rope_imp']
self.attn_config['rope_impl']
== 'dail') and (self.attn_config['rope_dail_config']['type']
not in ['original', 'xpos']):
raise ValueError(
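A hedged usage sketch tying the configuration pieces together: a partially specified `attn_config` passed to `MPTConfig` should have its remaining keys (including the nested rope dicts) filled from `attn_config_defaults`, and `rope_impl` is validated to be either 'dail' or 'hf'. Anything not visible in the diff above (such as the other constructor defaults) is an assumption:

```python
from llmfoundry.models.mpt.configuration_mpt import MPTConfig

# Assumption: MPTConfig's other constructor arguments have usable defaults.
cfg = MPTConfig(attn_config={
    'rope': True,
    'rope_impl': 'hf',
    'rope_hf_config': {
        'type': 'linear',
        'factor': 2.0,
    },
})

# Unspecified keys are expected to come from attn_config_defaults.
print(cfg.attn_config['rope_theta'])      # 10000 (default)
print(cfg.attn_config['rope_hf_config'])  # {'type': 'linear', 'factor': 2.0}

# An unrecognized implementation name should fail validation:
# MPTConfig(attn_config={'rope': True, 'rope_impl': 'other'})  # raises ValueError
```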
63 changes: 34 additions & 29 deletions llmfoundry/models/mpt/modeling_mpt.py
@@ -86,45 +86,44 @@
log = logging.getLogger(__name__)


def _rotary_embedding(config: MPTConfig):
rope_head_dim = config.d_model // config.n_heads
if config.attn_config['rope_imp'] == 'dail':
def _rotary_embedding(rope_head_dim: int, rope_impl: str, rope_theta: int,
rope_dail_config: dict, rope_hf_config: dict,
max_seq_len: int):
if rope_impl == 'dail':
return DAILRotaryEmbedding(
dim=rope_head_dim,
base=config.attn_config['rope_theta'],
base=rope_theta,
interleaved=False,
scale_base=config.attn_config['rope_dail_config']['xpos_scale_base']
if (config.attn_config['rope_dail_config']['type']
== 'xpos') else None,
pos_idx_in_fp32=config.attn_config['rope_dail_config']
['pos_idx_in_fp32'],
scale_base=rope_dail_config['xpos_scale_base'] if
(rope_dail_config['type'] == 'xpos') else None,
pos_idx_in_fp32=rope_dail_config['pos_idx_in_fp32'],
device=
'cpu', # FSDP does not materialize modules with meta buffers, hence device is set to cpu
)
elif config.attn_config['rope_imp'] == 'hf':
if config.attn_config['rope_hf_config']['type'] == 'no_scaling':
elif rope_impl == 'hf':
if rope_hf_config['type'] == 'no_scaling':
return HFRotaryEmbedding(
rope_head_dim,
max_position_embeddings=config.max_seq_len,
base=config.attn_config['rope_theta'],
max_position_embeddings=max_seq_len,
base=rope_theta,
device=
'cpu' # FSDP does not materialize modules with meta buffers, hence device is set to cpu
)
elif config.attn_config['rope_hf_config']['type'] == 'linear':
elif rope_hf_config['type'] == 'linear':
return HFLinearScalingRotaryEmbedding(
rope_head_dim,
max_position_embeddings=config.max_seq_len,
base=config.attn_config['rope_theta'],
scaling_factor=config.attn_config['rope_hf_config']['factor'],
max_position_embeddings=max_seq_len,
base=rope_theta,
scaling_factor=rope_hf_config['factor'],
device=
'cpu' # FSDP does not materialize modules with meta buffers, hence device is set to cpu
)
elif config.attn_config['rope_hf_config']['type'] == 'dynamic':
elif rope_hf_config['type'] == 'dynamic':
return HFDynamicNTKScalingRotaryEmbedding(
rope_head_dim,
max_position_embeddings=config.max_seq_len,
base=config.attn_config['rope_theta'],
scaling_factor=config.attn_config['rope_hf_config']['factor'],
max_position_embeddings=max_seq_len,
base=rope_theta,
scaling_factor=rope_hf_config['factor'],
device=
'cpu' # FSDP does not materialize modules with meta buffers, hence device is set to cpu
)
@@ -184,10 +183,16 @@ def __init__(self, config: MPTConfig):
self.norm_f = norm_class(config.d_model, device=config.init_device)

self.rope = config.attn_config['rope']
self.rope_imp = None
self.rope_impl = None
if self.rope:
self.rope_imp = config.attn_config['rope_imp']
self.rotary_embedding = _rotary_embedding(config)
self.rope_impl = config.attn_config['rope_impl']
self.rotary_embedding = _rotary_embedding(
rope_head_dim=config.d_model // config.n_heads,
rope_impl=self.rope_impl,
rope_theta=config.attn_config['rope_theta'],
rope_dail_config=config.attn_config['rope_dail_config'],
rope_hf_config=config.attn_config['rope_hf_config'],
max_seq_len=self.config.max_seq_len)

if config.init_device != 'meta':
log.info(
@@ -453,7 +458,7 @@ def forward(
f'{S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}.'
)

if self.learned_pos_emb or (self.rope and self.rope_imp == 'hf'):
if self.learned_pos_emb or (self.rope and self.rope_impl == 'hf'):
pos = torch.arange(
past_position,
S + past_position,
@@ -469,16 +474,16 @@
)
if self.learned_pos_emb:
x = x + self.wpe(pos)
elif self.rope and self.rope_imp == 'hf':
elif self.rope and self.rope_impl == 'hf':
rotary_emb_w_meta_info = {
'imp': self.rope_imp,
'imp': self.rope_impl,
'rotary_emb': self.rotary_embedding,
'offset_info': pos,
'seq_len': S + past_position,
}
elif self.rope and self.rope_imp == 'dail':
elif self.rope and self.rope_impl == 'dail':
rotary_emb_w_meta_info = {
'imp': self.rope_imp,
'imp': self.rope_impl,
'rotary_emb': self.rotary_embedding,
'offset_info': past_position,
'seq_len': S + past_position,
10 changes: 5 additions & 5 deletions tests/test_flash_triton_torch.py
@@ -31,7 +31,7 @@ def allclose_helper(t0: torch.Tensor,
'alibi': False,
'rope': True,
'rope_theta': 10000,
'rope_imp': 'dail',
'rope_impl': 'dail',
'rope_dail_config': {
'type': 'original',
'pos_idx_in_fp32': True,
@@ -41,7 +41,7 @@
'alibi': False,
'rope': True,
'rope_theta': 10000,
'rope_imp': 'hf',
'rope_impl': 'hf',
'rope_hf_config': {
'type': 'no_scaling',
'factor': 1.0,
@@ -68,7 +68,7 @@ def test_attn_impl(attn_impl_0: str,
if alibi and (attn_impl_0 == 'flash' or attn_impl_1 == 'flash'):
pytest.xfail('flash attn does not support alibi')

if rope and (pos_emb_config['rope_imp']
if rope and (pos_emb_config['rope_impl']
== 'dail') and (not is_flash_v2_installed()):
pytest.skip('dail implementation of rope requires flash attention 2.')

@@ -140,11 +140,11 @@ def gen_bias(attn_impl: str):
)
rotary_emb_w_meta_info = {
'imp':
pos_emb_config['rope_imp'],
pos_emb_config['rope_impl'],
'rotary_emb':
rotary_embedding,
'offset_info':
pos if (pos_emb_config['rope_imp'] == 'hf') else 0,
pos if (pos_emb_config['rope_impl'] == 'hf') else 0,
'seq_len':
s,
}