Fixing the gen_attention_mask_in_length function to handle the case when sequence id is -1 due to attention masking #940

Merged
Changes from all commits (32 commits)
04dd334
Merge pull request #1 from mosaicml/main
ShashankMosaicML Oct 9, 2023
87b2fdc
Merge pull request #8 from mosaicml/main
ShashankMosaicML Oct 27, 2023
c9a42e4
Merge pull request #12 from mosaicml/main
ShashankMosaicML Nov 6, 2023
ddea9ee
Merge branch 'mosaicml:main' into main
ShashankMosaicML Nov 6, 2023
0bcd8ee
Merge pull request #13 from mosaicml/main
ShashankMosaicML Nov 8, 2023
f209b58
Merge pull request #14 from mosaicml/main
ShashankMosaicML Nov 14, 2023
ec4378d
Merge pull request #15 from mosaicml/main
ShashankMosaicML Nov 15, 2023
b436706
Merge branch 'mosaicml:main' into main
ShashankMosaicML Dec 2, 2023
bcace03
..
ShashankMosaicML Dec 8, 2023
cf4aa58
Merge branch 'mosaicml:main' into main
ShashankMosaicML Dec 11, 2023
7c35ce8
Merge branch 'mosaicml:main' into main
ShashankMosaicML Dec 13, 2023
0a8ebfb
..
ShashankMosaicML Dec 15, 2023
6f18a33
..
ShashankMosaicML Dec 15, 2023
f42d585
Merge branch 'mosaicml:main' into main
ShashankMosaicML Dec 16, 2023
2f3f53c
Merge branch 'mosaicml:main' into main
ShashankMosaicML Dec 19, 2023
77b975f
..
ShashankMosaicML Dec 20, 2023
e28cfbe
Merge branch 'mosaicml:main' into main
ShashankMosaicML Jan 1, 2024
800c6f8
Merge branch 'mosaicml:main' into main
ShashankMosaicML Jan 2, 2024
922ef13
Merge branch 'mosaicml:main' into main
ShashankMosaicML Jan 3, 2024
d36f5f7
Merge branch 'mosaicml:main' into main
ShashankMosaicML Jan 5, 2024
d524531
Merge branch 'mosaicml:main' into main
ShashankMosaicML Jan 17, 2024
2b2f3d8
..
ShashankMosaicML Jan 17, 2024
25795b5
undoing prev commit
ShashankMosaicML Jan 17, 2024
624a339
Merge branch 'mosaicml:main' into main
ShashankMosaicML Jan 18, 2024
1c25b98
Merge branch 'mosaicml:main' into main
ShashankMosaicML Jan 29, 2024
d25cf2e
Merge branch 'mosaicml:main' into main
ShashankMosaicML Feb 1, 2024
1cc4505
Merge branch 'mosaicml:main' into main
ShashankMosaicML Feb 3, 2024
93fe393
Merge branch 'mosaicml:main' into main
ShashankMosaicML Feb 3, 2024
9bc62c9
fixing the gen_attention_mask_in_length function to handle the case w…
ShashankMosaicML Feb 5, 2024
45d832f
Update modeling_mpt.py
ShashankMosaicML Feb 5, 2024
f1ded14
..
ShashankMosaicML Feb 5, 2024
b55a572
Merge branch 'main' into shashank/fix_gen_attention_mask_in_length_at…
dakinggg Feb 5, 2024
llmfoundry/models/mpt/modeling_mpt.py (5 additions, 0 deletions)
@@ -212,6 +212,11 @@ def gen_attention_mask_in_length(sequence_id: Union[None, torch.Tensor], S: int,
             raise ValueError(
                 f'Sequence length ({S}) does not match length of sequences in sequence_id ({sequence_id.shape[-1]}).'
             )
+        if attention_mask is not None:
+            # -1 is used to pad the sequence_id where attention mask is False (https://github.com/mosaicml/llm-foundry/blob/706ea7dd40ba60a98dea5f37695d143d91c98b6c/llmfoundry/data/packing.py#L249).
+            # We replace those -1 with 0 to prevent `torch.nn.functional.one_hot(sequence_id)` in the next line from failing.
+            # We apply the attention mask again after the one_hot operation.
+            sequence_id = sequence_id.masked_fill(~attention_mask, 0)
         attention_mask_in_length = torch.nn.functional.one_hot(sequence_id)
         if attention_mask is not None:
             attention_mask_in_length = attention_mask_in_length.masked_fill(
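For context, a minimal standalone sketch (not code from this PR; example tensors and the final sum reduction are illustrative assumptions) of why the padded -1 entries in sequence_id must be zeroed before one_hot, and why the attention mask is re-applied afterwards so padded positions do not get counted as tokens of sequence 0:

import torch

# Two packed sequences (ids 0 and 1) followed by two padded positions marked -1,
# the convention used in llmfoundry/data/packing.py.
sequence_id = torch.tensor([[0, 0, 1, 1, -1, -1]])
attention_mask = torch.tensor([[True, True, True, True, False, False]])

# torch.nn.functional.one_hot rejects negative class indices, so the padded
# -1 entries are first replaced with 0 ...
safe_sequence_id = sequence_id.masked_fill(~attention_mask, 0)
one_hot = torch.nn.functional.one_hot(safe_sequence_id)

# ... and the attention mask is applied again afterwards so the padded
# positions contribute nothing.
one_hot = one_hot.masked_fill(~attention_mask.unsqueeze(-1), 0)

# Illustrative reduction (assumption): token count per packed sequence.
print(one_hot.sum(dim=1))  # tensor([[2, 2]])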
tests/models/layers/test_flash_triton_torch.py (4 additions, 0 deletions)
@@ -134,6 +134,10 @@ def test_attn_impl(attn_impl_0: str,
         # zero out the last third of the attention mask
         # to simulate padding
         attention_mask[:, -s // 3:] = 0
+        if sequence_id is not None:
+            sequence_id = sequence_id.masked_fill(
+                ~attention_mask, -1
+            )  # Similar to how we set sequence id for padded tokens: https://github.com/mosaicml/llm-foundry/blob/706ea7dd40ba60a98dea5f37695d143d91c98b6c/llmfoundry/data/packing.py#L249

     def gen_bias(attn_impl: str):
         causal = True
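A small sketch of what this test setup produces (shapes and values are assumed for illustration, not taken from the test): the last third of the attention mask is zeroed to simulate padding, and sequence_id is set to -1 at those positions, mirroring how packing.py marks padded tokens.

import torch

n, s = 2, 6
attention_mask = torch.ones(n, s, dtype=torch.bool)
attention_mask[:, -s // 3:] = 0  # zero out the last third to simulate padding

# One packed sequence per row (assumption); padded positions are marked with -1.
sequence_id = torch.zeros(n, s, dtype=torch.long)
sequence_id = sequence_id.masked_fill(~attention_mask, -1)
print(sequence_id)
# tensor([[ 0,  0,  0,  0, -1, -1],
#         [ 0,  0,  0,  0, -1, -1]])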