Add shifted sparse attention #973
Conversation
Force-pushed from 42a0645 to a6be9cb.
@joecummings do you have time to rebase this onto main? If not, I can take a stab at rebasing later this week.
yep I'll do this later today!
Force-pushed from cb08b2d to 7628056.
Force-pushed from 0412089 to 4135039.
Looks good to me. @NanoCode012 can you take a quick look please to make sure I didn't miss anything? Thanks
# Modify all llama derived models in one block
if cfg.is_llama_derived_model:
Since it applies to llama models, do we need to account for mistral here as well?
@joecummings this should work with mistral and mixtral too, right?
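If mistral and mixtral do need the same handling, the guard could presumably be widened along these lines. This is only a sketch, and it assumes the config carries a mistral counterpart to `is_llama_derived_model`, which is not shown in this diff:

```python
def is_s2_attn_supported_model(cfg) -> bool:
    # Sketch only, not the code in this PR: broaden the llama-only check to
    # also cover mistral-derived models, assuming a flag analogous to
    # `is_llama_derived_model` exists on the config object.
    return bool(
        getattr(cfg, "is_llama_derived_model", False)
        or getattr(cfg, "is_mistral_derived_model", False)
    )
```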
cross_entropy=cfg.flash_attn_cross_entropy,
rms_norm=cfg.flash_attn_rms_norm,
# ...
if cfg.sample_packing:
    if cfg.device not in ["mps", "cpu"] and not inference:
Hm, does it mean FA won't be enabled for inference mode now?
I don't think FA was ever enabled for inference with sample packing; here's the original code:

if cfg.is_llama_derived_model and cfg.flash_attention and cfg.sample_packing:
    if cfg.device not in ["mps", "cpu"] and not inference:
        from axolotl.monkeypatch.llama_attn_hijack_flash import (
            replace_llama_attn_with_flash_attn,
        )

        LOG.info("patching with flash attention for sample packing")
        replace_llama_attn_with_flash_attn(
            packed=cfg.sample_packing,
            cross_entropy=cfg.flash_attn_cross_entropy,
            rms_norm=cfg.flash_attn_rms_norm,
        )
Thanks for all your work on this, @joecummings!
Squashed commit message:
* Add s2_attn to hijack flash code
* Refactor code to account for s2_attn
* Add test for models utils
* Add `s2_attention` option to llama configs
* Add `s2_attention` option to README config
* Format code to appease linter
* chore: lint
* Remove xpos and llama-landmark [bad merge]
* add e2e smoke tests for shifted sparse attention
* remove stray patch from merge
* update yml with link to paper for s2_attention/longlora
* fix assertion check for full fine tune
* increase sequence len for tests and PR feedback updates
* reduce context len to 16k for tests
* reduce context len to 16k for tests
* reduce batch size for larger context len and update test to check message
* fix test for message

Co-authored-by: joecummings <[email protected]>
Co-authored-by: Wing Lian <[email protected]>
Summary
Add shifted sparse attention (with flash attention) to enable longer-context training with less memory overhead.
Paper: https://arxiv.org/pdf/2309.12307.pdf
Code: https://github.com/dvlab-research/LongLoRA/tree/main
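For context, the core idea in S²-Attn is to compute attention within fixed-size groups of tokens while shifting half of the attention heads by half a group, so information still flows across group boundaries. The sketch below illustrates just the shift step; the tensor shape and function name are assumptions for illustration, not axolotl's implementation:

```python
import torch


def shift_half_heads(x: torch.Tensor, group_size: int) -> torch.Tensor:
    """Illustrative S2-Attn shift (not axolotl's code).

    x: (batch, seq_len, num_heads, head_dim) query/key/value tensor.
    The second half of the heads is rolled by half a group so that, once the
    sequence is split into groups of `group_size`, those heads attend across
    the boundaries of the other heads' groups.
    """
    num_heads = x.shape[2]
    shifted = x.clone()
    shifted[:, :, num_heads // 2 :] = shifted[:, :, num_heads // 2 :].roll(
        -group_size // 2, dims=1
    )
    return shifted
```

After grouped attention is computed, the same heads are rolled back by `group_size // 2` so the outputs line up with their original token positions.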
Testing
Added a test that checks a `ValueError` is raised if `sample_packing = True` and `s2_attention = True`:

`pytest tests/utils/test_models.py::ModelsUtilsTest::test_cfg_throws_error_with_s2_attention_and_sample_packing`
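Roughly, the behavior being tested looks like the standalone sketch below; the function name and error message are placeholders, since the real guard lives inside axolotl's model/config handling:

```python
import pytest


def validate_s2_attention_cfg(cfg: dict) -> None:
    # Hypothetical standalone version of the guard this PR adds:
    # shifted sparse attention and sample packing cannot be combined.
    if cfg.get("s2_attention") and cfg.get("sample_packing"):
        raise ValueError("s2_attention is not compatible with sample_packing")


def test_s2_attention_with_sample_packing_raises():
    with pytest.raises(ValueError):
        validate_s2_attention_cfg({"s2_attention": True, "sample_packing": True})
```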
Run `accelerate launch -m axolotl.cli.train examples/openllama-3b/lora.yml` with the following config changes: [INSERT WANDB LOG HERE]
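The exact config diff behind that run isn't reproduced here, but the change it implies is turning on the new option in the example LoRA config. Below is a hedged sketch of producing such a config programmatically; only `s2_attention`, `sample_packing`, and `flash_attention` are keys named elsewhere in this PR, and the output filename is made up:

```python
import yaml

# Illustrative sketch, not the actual diff behind the W&B run: load the
# example LoRA config and enable shifted sparse attention.
with open("examples/openllama-3b/lora.yml") as f:
    cfg = yaml.safe_load(f)

cfg["s2_attention"] = True      # option added by this PR
cfg["sample_packing"] = False   # incompatible with s2_attention (see the test above)
cfg["flash_attention"] = True   # s2_attention builds on the flash attention patch

with open("examples/openllama-3b/lora-s2.yml", "w") as f:
    yaml.safe_dump(cfg, f)
```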
Follow-ups
Train `embed` and `norm` during LoRA, which improves performance according to the above paper (e.g. LoRA+); see the sketch below.
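If this follow-up goes through PEFT, one common way to make those layers trainable alongside the LoRA adapters is `modules_to_save`. The snippet below is a sketch under assumed Llama-style module names (`embed_tokens`, `norm`) and placeholder LoRA hyperparameters; it is not taken from this PR:

```python
from peft import LoraConfig

# Sketch: make the token embedding and final norm trainable alongside the LoRA
# adapters, per the LongLoRA paper's recommendation. Module names assume a
# Llama-style model; hyperparameters are placeholders.
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    modules_to_save=["embed_tokens", "norm"],
    task_type="CAUSAL_LM",
)
```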