Add shifted sparse attention #973

Merged (17 commits), Jan 18, 2024


@joecummings joecummings commented Dec 17, 2023

Summary

Add shifted sparse attention (w/ flash attention) to enable longer context training w/ less memory overhead.

Paper: https://arxiv.org/pdf/2309.12307.pdf
Code: https://github.com/dvlab-research/LongLoRA/tree/main
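
For readers unfamiliar with the idea, here is a minimal pure-Python sketch of the token grouping that shifted sparse attention relies on, as described in the LongLoRA paper: attention is computed within fixed-size groups, and half of the heads use groups shifted by half a group size so information can flow between neighbouring groups. The helper name `s2_groups` is hypothetical and the wrap-around roll is an assumption based on the reference implementation; this is not the code added in this PR.

```python
def s2_groups(seq_len, group_size, shifted):
    """Return the token indices attended together in each group.

    Hypothetical helper for illustration only. `shifted=True` models the
    half of the heads whose tokens are rolled by half a group size.
    """
    if seq_len % group_size != 0:
        raise ValueError("seq_len must be divisible by group_size")
    offset = group_size // 2 if shifted else 0
    # Roll the token order left by `offset`: slot i now holds token i + offset,
    # wrapping around at the end of the sequence.
    order = [(i + offset) % seq_len for i in range(seq_len)]
    # Cut the (possibly rolled) order into contiguous groups.
    return [order[s:s + group_size] for s in range(0, seq_len, group_size)]


print(s2_groups(8, 4, shifted=False))  # [[0, 1, 2, 3], [4, 5, 6, 7]]
print(s2_groups(8, 4, shifted=True))   # [[2, 3, 4, 5], [6, 7, 0, 1]]
```

Each group attends only within itself, so the attention cost grows with the group size rather than the full sequence length; the shifted heads bridge the boundaries that the unshifted heads cannot see across.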

Testing

  • Added test to check for raised ValueError if sample_packing = True and s2_attention = True
    pytest tests/utils/test_models.py::ModelsUtilsTest::test_cfg_throws_error_with_s2_attention_and_sample_packing

  • Run accelerate launch -m axolotl.cli.train examples/openllama-3b/lora.yml with the following config changes:

datasets:
  - path: Yukang/LongAlpaca-12k # From LongLoRA paper
    type: alpaca
sequence_len:  65536
s2_attention: true

[INSERT WANDB LOG HERE]
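
The first test above checks that combining `sample_packing` with `s2_attention` is rejected. A minimal sketch of such a guard might look like the following; the function name `check_s2_attention_config` and the error message are hypothetical stand-ins, not the actual code in `src/axolotl/utils/models.py`.

```python
from types import SimpleNamespace


def check_s2_attention_config(cfg):
    """Reject configs that enable both s2_attention and sample_packing.

    Hypothetical helper; the wording of the message is illustrative.
    """
    if getattr(cfg, "s2_attention", False) and getattr(cfg, "sample_packing", False):
        raise ValueError(
            "s2_attention=true cannot be combined with sample_packing=true"
        )


# A config with only s2_attention enabled passes silently.
check_s2_attention_config(SimpleNamespace(s2_attention=True, sample_packing=False))

# Enabling both raises ValueError.
try:
    check_s2_attention_config(SimpleNamespace(s2_attention=True, sample_packing=True))
except ValueError as err:
    print(f"rejected: {err}")
```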

Follow-ups

@joecummings joecummings marked this pull request as ready for review December 18, 2023 00:49
@winglian winglian requested a review from NanoCode012 December 18, 2023 15:00
@joecummings joecummings force-pushed the feature/add-s2-attn branch 3 times, most recently from 42a0645 to a6be9cb Compare December 19, 2023 03:54
@winglian
Collaborator

@joecummings do you have time to rebase this onto main? If not, I can take a stab at rebasing later this week.

@joecummings
Contributor Author

> @joecummings do you have time to rebase this onto main? If not, I can take a stab at rebasing later this week.

yep I'll do this later today!

@joecummings joecummings force-pushed the feature/add-s2-attn branch 2 times, most recently from cb08b2d to 7628056 Compare January 11, 2024 15:04
Collaborator

@winglian winglian left a comment

Looks good to me. @NanoCode012, can you take a quick look please to make sure I didn't miss anything? Thanks.

src/axolotl/utils/models.py (outdated)
)

# Modify all llama derived models in one block
if cfg.is_llama_derived_model:
Collaborator

Since it applies to llama models, do we need to account for mistral here as well?

Collaborator

@joecummings this should work with mistral and mixtral too, right?

cross_entropy=cfg.flash_attn_cross_entropy,
rms_norm=cfg.flash_attn_rms_norm,
if cfg.sample_packing:
if cfg.device not in ["mps", "cpu"] and not inference:
Collaborator

Hm, does it mean, FA won't be enabled for inference mode now?

Collaborator

I don't think FA was ever enabled for inference. Here's the original code:

    if cfg.is_llama_derived_model and cfg.flash_attention and cfg.sample_packing:
        if cfg.device not in ["mps", "cpu"] and not inference:
            from axolotl.monkeypatch.llama_attn_hijack_flash import (
                replace_llama_attn_with_flash_attn,
            )

            LOG.info("patching with flash attention for sample packing")
            replace_llama_attn_with_flash_attn(
                packed=cfg.sample_packing,
                cross_entropy=cfg.flash_attn_cross_entropy,
                rms_norm=cfg.flash_attn_rms_norm,
            )
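
To make the condition in the snippet above concrete, here is a small sketch that mirrors it as a predicate. The helper name `should_patch_llama_flash_attn` and the example `device` value are hypothetical; only the boolean logic is taken from the quoted code.

```python
from types import SimpleNamespace


def should_patch_llama_flash_attn(cfg, inference=False):
    """Mirror the guard from the quoted snippet (hypothetical helper)."""
    return bool(
        cfg.is_llama_derived_model
        and cfg.flash_attention
        and cfg.sample_packing
        and cfg.device not in ["mps", "cpu"]
        and not inference
    )


cfg = SimpleNamespace(
    is_llama_derived_model=True,
    flash_attention=True,
    sample_packing=True,
    device="cuda",  # assumed value for illustration
)

print(should_patch_llama_flash_attn(cfg, inference=False))  # True
print(should_patch_llama_flash_attn(cfg, inference=True))   # False
```

Under this reading, the patch was already skipped whenever `inference` was true, which matches the point made in the reply.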

src/axolotl/utils/models.py (resolved)
tests/e2e/patched/test_llama_s2_attention.py (outdated, resolved)
tests/e2e/patched/test_llama_s2_attention.py (outdated, resolved)
tests/utils/test_models.py (outdated, resolved)
@winglian winglian requested a review from NanoCode012 January 18, 2024 13:13
@winglian winglian merged commit 1d70f24 into axolotl-ai-cloud:main Jan 18, 2024
7 checks passed
@winglian
Collaborator

Thanks for all your work on this @joecummings !

@joecummings joecummings deleted the feature/add-s2-attn branch January 18, 2024 16:53
djsaunde pushed a commit that referenced this pull request Dec 17, 2024
* Add s2_attn to hijack flash code

* Refactor code to account for s2_attn

* Add test for models utils

* Add ``s2_attention`` option to llama configs

* Add ``s2_attention`` option to README config

* Format code to appease linter

* chore: lint

* Remove xpos and llama-landmark [bad merge]

* add e2e smoke tests for shifted sparse attention

* remove stray patch from merge

* update yml with link to paper for s2_attention/longlora

* fix assertion check for full fine tune

* increase sequence len for tests and PR feedback updates

* reduce context len to 16k for tests

* reduce context len to 16k for tests

* reduce batch size for larger context len and update test to check message

* fix test for message

---------

Co-authored-by: joecummings <[email protected]>
Co-authored-by: Wing Lian <[email protected]>
3 participants