Remove redundant transposes for rope rotation #807

Merged
Changes from 20 commits
Commits (22)
04dd334
Merge pull request #1 from mosaicml/main
ShashankMosaicML Oct 9, 2023
87b2fdc
Merge pull request #8 from mosaicml/main
ShashankMosaicML Oct 27, 2023
c9a42e4
Merge pull request #12 from mosaicml/main
ShashankMosaicML Nov 6, 2023
ddea9ee
Merge branch 'mosaicml:main' into main
ShashankMosaicML Nov 6, 2023
0bcd8ee
Merge pull request #13 from mosaicml/main
ShashankMosaicML Nov 8, 2023
f209b58
Merge pull request #14 from mosaicml/main
ShashankMosaicML Nov 14, 2023
ec4378d
Merge pull request #15 from mosaicml/main
ShashankMosaicML Nov 15, 2023
b436706
Merge branch 'mosaicml:main' into main
ShashankMosaicML Dec 2, 2023
bcace03
..
ShashankMosaicML Dec 8, 2023
cf4aa58
Merge branch 'mosaicml:main' into main
ShashankMosaicML Dec 11, 2023
7c35ce8
Merge branch 'mosaicml:main' into main
ShashankMosaicML Dec 13, 2023
0a8ebfb
..
ShashankMosaicML Dec 15, 2023
6f18a33
..
ShashankMosaicML Dec 15, 2023
f42d585
Merge branch 'mosaicml:main' into main
ShashankMosaicML Dec 16, 2023
6535d04
..
ShashankMosaicML Dec 16, 2023
fff3b48
Merge branch 'main' into shashank/fix_redundant_transposes_rope
ShashankMosaicML Dec 19, 2023
e2d33eb
..
ShashankMosaicML Dec 20, 2023
fe4afd1
Merge branch 'main' into shashank/fix_redundant_transposes_rope
ShashankMosaicML Dec 20, 2023
e9fabad
merging
ShashankMosaicML Dec 20, 2023
d2602b1
Merge branch 'main' into shashank/fix_redundant_transposes_rope
ShashankMosaicML Dec 20, 2023
d94baa6
Update llmfoundry/models/layers/attention.py
ShashankMosaicML Dec 20, 2023
1b4eb95
..
ShashankMosaicML Dec 20, 2023
27 changes: 19 additions & 8 deletions llmfoundry/models/layers/attention.py
@@ -9,6 +9,7 @@

 import torch
 import torch.nn as nn
+import transformers
 from einops import rearrange
 from packaging import version
 from torch import nn
@@ -34,6 +35,10 @@ def is_flash_v1_installed():
     return version.parse(flash_attn.__version__) < version.parse('2.0.0')


+def check_transformers_version(hf_version: str):
+    return version.parse(transformers.__version__) >= version.parse(hf_version)
+
+
 # Before importing any transformers models, we need to disable transformers flash attention if
 # we are in an environment with flash attention version <2. Transformers hard errors on a not properly
 # gated import otherwise.
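A brief sketch of how the comparison inside the new helper behaves, using only packaging.version (already imported in this file). The version strings below are illustrative and not tied to any particular environment.

import transformers
from packaging import version

# Same comparison the new helper performs, evaluated against whatever
# transformers release is installed in the current environment.
print(version.parse(transformers.__version__) >= version.parse('4.36'))

# Patch releases of 4.36 satisfy the gate; older minor versions do not,
# so they fall back to the transpose-based RoPE path in the hunk below.
assert version.parse('4.36.2') >= version.parse('4.36')
assert not (version.parse('4.35.2') >= version.parse('4.36'))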
@@ -627,14 +632,20 @@ def forward(
                 value = value.view(bsz, seqlen, self.kv_n_heads * self.head_dim)
             elif rotary_emb_w_meta_info['impl'] == 'hf':
                 (cos, sin) = rotary_emb(value, seq_len)
-                # The following two transposes should be removed once the transformers library allows for the specification of the dimension for heads in the call to apply_rotary_pos_emb
-                query = query.transpose(1, 2)
-                key = key.transpose(1, 2)
-                query, key = apply_rotary_pos_emb(query, key, cos, sin,
-                                                  offset_info)
-                # The following two transposes should be removed once the transformers library allows for the specification of the dimension for heads in the call to apply_rotary_pos_emb
-                query = query.transpose(1, 2)
-                key = key.transpose(1, 2)
+                if check_transformers_version('4.36'):
+                    query, key = apply_rotary_pos_emb(query,
+                                                      key,
+                                                      cos,
+                                                      sin,
+                                                      offset_info,
+                                                      unsqueeze_dim=2)
+                else:
+                    query = query.transpose(1, 2)
+                    key = key.transpose(1, 2)
+                    query, key = apply_rotary_pos_emb(query, key, cos, sin,
+                                                      offset_info)
+                    query = query.transpose(1, 2)
+                    key = key.transpose(1, 2)

             query = query.view(bsz, seqlen, self.d_model)
             key = key.view(bsz, seqlen, self.kv_n_heads * self.head_dim)
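The hunk above relies on an equivalence: applying Hugging Face-style rotary embeddings directly in the (batch, seq, heads, head_dim) layout with the head axis declared via unsqueeze_dim=2 gives the same result as the old path, which transposed to (batch, heads, seq, head_dim), rotated with the default unsqueeze_dim=1, and transposed back. The self-contained sketch below checks that equivalence; rotate_half, apply_rope, and the frequency setup are simplified stand-ins for the transformers helpers, not the library code itself.

import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Swap and negate the two halves of the last (head_dim) dimension.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rope(q, k, cos, sin, unsqueeze_dim):
    # Simplified stand-in for apply_rotary_pos_emb: cos/sin are
    # (batch, seq, head_dim) and get a singleton head axis inserted at
    # `unsqueeze_dim` so they broadcast against q and k.
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin


bsz, seqlen, n_heads, head_dim = 2, 5, 4, 8
query = torch.randn(bsz, seqlen, n_heads, head_dim)
key = torch.randn(bsz, seqlen, n_heads, head_dim)

# Illustrative rotary frequencies, shaped (batch, seq, head_dim).
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2) / head_dim))
angles = torch.outer(torch.arange(seqlen, dtype=torch.float), inv_freq)
angles = torch.cat((angles, angles), dim=-1).expand(bsz, -1, -1)
cos, sin = angles.cos(), angles.sin()

# Old path: transpose to (b, h, s, d), rotate with the head axis at dim 1,
# then transpose back to (b, s, h, d).
q_old, k_old = apply_rope(query.transpose(1, 2), key.transpose(1, 2),
                          cos, sin, unsqueeze_dim=1)
q_old, k_old = q_old.transpose(1, 2), k_old.transpose(1, 2)

# New path: stay in (b, s, h, d) and declare the head axis as dim 2.
q_new, k_new = apply_rope(query, key, cos, sin, unsqueeze_dim=2)

assert torch.allclose(q_old, q_new) and torch.allclose(k_old, k_new)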