Add shifted sparse attention #973

Merged: winglian merged 17 commits into axolotl-ai-cloud:main from joecummings:feature/add-s2-attn on Jan 18, 2024.
Changes from all commits (17 commits); a sketch of how the new `s2_attention` option is typically configured follows the commit list.
- 2b2fd52 joecummings: Add s2_attn to hijack flash code
- 1450af9 joecummings: Refactor code to account for s2_attn
- 60126bf joecummings: Add test for models utils
- 5e66cb4 joecummings: Add ``s2_attention`` option to llama configs
- 0f57f30 joecummings: Add ``s2_attention`` option to README config
- cb335d8: Format code to appease linter
- dcb5694 winglian: chore: lint
- 4135039 joecummings: Remove xpos and llama-landmark [bad merge]
- 34c62fb winglian: add e2e smoke tests for shifted sparse attention
- cb899d9 winglian: remove stray patch from merge
- 02d1e90 winglian: update yml with link to paper for s2_attention/longlora
- e8ba3fe winglian: fix assertion check for full fine tune
- 9292665 winglian: increase sequence len for tests and PR feedback updates
- 5e0890d winglian: reduce context len to 16k for tests
- bee8f8c winglian: reduce context len to 16k for tests
- e6e67dd winglian: reduce batch size for larger context len and udpate test to check mes…
- 4f09ef4 winglian: fix test for message
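Per the two ``s2_attention`` config commits above, the PR adds a boolean flag to the example llama configs. Below is a minimal sketch of how the relevant keys fit together, written as a Python dict mirroring an axolotl YAML config; the base model, sequence length, and every key other than `s2_attention` are illustrative assumptions rather than values taken from this PR.

```python
# Hedged sketch of the config keys this PR touches, expressed as a Python dict
# that mirrors an axolotl YAML config. Only `s2_attention` is introduced by the
# PR; the other keys and their values are illustrative assumptions.
example_cfg = {
    "base_model": "NousResearch/Llama-2-7b-hf",  # assumption: any llama-derived model
    "sequence_len": 16384,       # long-context target, echoing the 16k used in the e2e tests
    "flash_attention": True,     # required: s2_attention is only implemented on the flash path
    "s2_attention": True,        # enable shifted-sparse (LongLoRA S^2) attention
    "sample_packing": False,     # mutually exclusive with s2_attention (see ValueError in the diff)
}
```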
Diff on `load_model` (first hunk):

```diff
@@ -256,31 +256,55 @@ def load_model(
         replace_stablelm_attn_with_flash_attn(cfg.base_model)
 
-    if cfg.is_llama_derived_model and cfg.flash_attention and cfg.sample_packing:
-        if cfg.device not in ["mps", "cpu"] and not inference:
+    if cfg.sample_packing and cfg.s2_attention:
+        raise ValueError(
+            "Received `sample_packing=true` and `s2_attention=true`; however, \
+            shifted-sparse attention does not currently support sample packing."
+        )
+
+    # Modify all llama derived models in one block
+    if cfg.is_llama_derived_model:
+        if cfg.flash_attention:
             from axolotl.monkeypatch.llama_attn_hijack_flash import (
                 replace_llama_attn_with_flash_attn,
             )
 
-            LOG.info("patching with flash attention for sample packing")
-            replace_llama_attn_with_flash_attn(
-                packed=cfg.sample_packing,
-                cross_entropy=cfg.flash_attn_cross_entropy,
-                rms_norm=cfg.flash_attn_rms_norm,
-            )
+            if cfg.sample_packing:
+                if cfg.device not in ["mps", "cpu"] and not inference:
+                    LOG.info("patching with flash attention for sample packing")
+                    replace_llama_attn_with_flash_attn(
+                        packed=True,
+                        cross_entropy=cfg.flash_attn_cross_entropy,
+                        rms_norm=cfg.flash_attn_rms_norm,
+                    )
+            elif cfg.s2_attention:
+                LOG.info("patching w/ flash-enabled, shifted-sparse attention")
+                replace_llama_attn_with_flash_attn(
+                    packed=False,
+                    cross_entropy=cfg.flash_attn_cross_entropy,
+                    rms_norm=cfg.flash_attn_rms_norm,
+                    use_shifted_sparse_attn=True,
+                )
-    elif cfg.is_llama_derived_model and cfg.xformers_attention:
-        from axolotl.monkeypatch.llama_attn_hijack_xformers import (
-            hijack_llama_attention,
-        )
+        elif cfg.xformers_attention:
+            from axolotl.monkeypatch.llama_attn_hijack_xformers import (
+                hijack_llama_attention,
+            )
 
-        LOG.info("patching with xformers attention")
-        hijack_llama_attention()
+            LOG.info("patching with xformers attention")
+            hijack_llama_attention()
-    elif cfg.is_llama_derived_model and cfg.sdp_attention:
-        from axolotl.monkeypatch.llama_attn_hijack_sdp import hijack_llama_sdp_attention
+        elif cfg.sdp_attention:
+            from axolotl.monkeypatch.llama_attn_hijack_sdp import (
+                hijack_llama_sdp_attention,
+            )
 
-        LOG.info("patching with sdp attention")
-        hijack_llama_sdp_attention()
+            LOG.info("patching with sdp attention")
+            hijack_llama_sdp_attention()
+        elif cfg.s2_attention:
+            raise NotImplementedError(
+                "Shifted-sparse attention not currently implemented without flash attention."
+            )
 
     # Modify mistral derived models
     if cfg.is_mistral_derived_model and cfg.flash_attention and cfg.sample_packing:
         from axolotl.monkeypatch.mistral_attn_hijack_flash import (
             replace_mistral_attn_with_flash_attn,
```

Review thread on the new line `if cfg.device not in ["mps", "cpu"] and not inference:`:

> Hm, does it mean FA won't be enabled for inference mode now?

> I don't think FA was ever enabled for inference. Here's the original code:
>
> ```python
> if cfg.is_llama_derived_model and cfg.flash_attention and cfg.sample_packing:
>     if cfg.device not in ["mps", "cpu"] and not inference:
>         from axolotl.monkeypatch.llama_attn_hijack_flash import (
>             replace_llama_attn_with_flash_attn,
>         )
>         LOG.info("patching with flash attention for sample packing")
>         replace_llama_attn_with_flash_attn(
>             packed=cfg.sample_packing,
>             cross_entropy=cfg.flash_attn_cross_entropy,
>             rms_norm=cfg.flash_attn_rms_norm,
>         )
> ```
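For readers unfamiliar with the technique being wired in above: shifted sparse attention (S²-Attn, from the LongLoRA paper linked in commit 02d1e90) computes attention within fixed-size groups of tokens and, for half of the attention heads, shifts the sequence by half a group so that information still flows across group boundaries. The snippet below is a minimal, self-contained sketch of that idea in plain PyTorch, not the flash-attention monkeypatch this PR installs; the tensor layout, the use of `scaled_dot_product_attention`, and the omission of the boundary-group mask for the shifted heads are simplifying assumptions.

```python
import torch
import torch.nn.functional as F


def s2_attn_sketch(q, k, v, group_size: int):
    """Toy shifted-sparse attention. q, k, v: [batch, heads, seq_len, head_dim];
    seq_len must be divisible by group_size and heads by 2."""
    bsz, n_heads, seq_len, head_dim = q.shape
    half_heads, half_group = n_heads // 2, group_size // 2

    def shift(x, direction):
        # Roll the second half of the heads along the sequence axis so their
        # attention groups straddle the group boundaries of the unshifted heads.
        out = x.clone()
        out[:, half_heads:] = torch.roll(x[:, half_heads:], shifts=direction * half_group, dims=2)
        return out

    def to_groups(x):
        # Fold each group of tokens into the batch dimension:
        # [b, h, s, d] -> [b * s/group, h, group, d]
        return (
            x.reshape(bsz, n_heads, seq_len // group_size, group_size, head_dim)
            .transpose(1, 2)
            .reshape(bsz * (seq_len // group_size), n_heads, group_size, head_dim)
        )

    q, k, v = shift(q, -1), shift(k, -1), shift(v, -1)
    # Ordinary causal attention, restricted to each group. (A faithful version would
    # also mask the wrap-around tokens in the first group of the shifted heads;
    # omitted here for brevity.)
    out = F.scaled_dot_product_attention(to_groups(q), to_groups(k), to_groups(v), is_causal=True)
    # Undo the grouping, then undo the shift.
    out = (
        out.reshape(bsz, seq_len // group_size, n_heads, group_size, head_dim)
        .transpose(1, 2)
        .reshape(bsz, n_heads, seq_len, head_dim)
    )
    return shift(out, +1)


if __name__ == "__main__":
    x = torch.randn(2, 8, 64, 16)
    print(s2_attn_sketch(x, x, x, group_size=16).shape)  # torch.Size([2, 8, 64, 16])
```

In the PR itself this logic lives inside the flash-attention monkeypatch (`replace_llama_attn_with_flash_attn(..., use_shifted_sparse_attn=True)`), which is why enabling `s2_attention` without flash attention raises the `NotImplementedError` shown in the diff above.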
Second hunk (`@@ -387,9 +411,12 @@ def load_model`), shown as it reads after the change; a review conversation on this hunk was marked as resolved by winglian:

```python
        model_kwargs["quantization_config"] = BitsAndBytesConfig(
            **bnb_config,
        )

    # sample packing uses custom FA2 patch
    if cfg.flash_attention:
        if not cfg.sample_packing:
            if cfg.s2_attention:
                pass
            if (
                cfg.is_llama_derived_model
                or cfg.is_falcon_derived_model
```
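The commit "Add test for models utils" suggests the new validation path is unit-tested. Below is a hedged sketch of what such a test could look like; the import path, the `load_model` call signature, and the `minimal_llama_cfg` fixture are assumptions for illustration, not code from this PR.

```python
import pytest

from axolotl.utils.models import load_model  # import path assumed for illustration


def test_sample_packing_and_s2_attention_are_mutually_exclusive(minimal_llama_cfg):
    # `minimal_llama_cfg` is a hypothetical fixture yielding a llama-derived config
    # with flash_attention enabled; its exact shape is an assumption.
    cfg = minimal_llama_cfg
    cfg.sample_packing = True
    cfg.s2_attention = True

    # The guard added at the top of load_model should fail fast with the ValueError
    # from this diff, before any model weights are touched.
    with pytest.raises(ValueError, match="does not currently support sample packing"):
        load_model(cfg, tokenizer=None)  # assumed signature for illustration
```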
Review comments:

> Since it applies to llama models, do we need to account for mistral here as well?

> @joecummings this should work with mistral and mixtral too, right?