Mistral: Sliding Window Attention with Flash Attention and Sample Packing #732

Merged: winglian merged 10 commits into axolotl-ai-cloud:main from casper-hansen:mistral_fa_swa on Oct 16, 2023.
Conversation
mkeoliya pushed a commit to mkeoliya/axolotl that referenced this pull request on Dec 15, 2023:
Mistral: Sliding Window Attention with Flash Attention and Sample Packing (axolotl-ai-cloud#732)

* Implement Mistral FA + SWA + Sample Packing
* Handle unbroadcastable tensor
* chore: lint
* Simplify _prepare_decoder_attention_mask
* Uncomment window size
* Upgrade flash-attn to minimum of 2.3.0 to support SWA
* Add original condition to avoid error during inference
* chore: lint
* use torchscript to prevent oom
* chore: pylint

Co-authored-by: Wing Lian <[email protected]>
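Regarding the "Simplify _prepare_decoder_attention_mask" and "use torchscript to prevent oom" items in the commit list above, the sketch below shows the kind of `torch.jit.script`-decorated sliding-window mask helper those steps suggest. The function name and the boolean "disallowed" mask convention are illustrative assumptions, not the PR's actual code.

```python
import torch


@torch.jit.script
def sliding_window_disallowed_mask(seq_len: int, window: int, device: torch.device) -> torch.Tensor:
    # True where attention is NOT allowed: future positions, or positions more than
    # `window` tokens in the past. The commit list mentions using torchscript to
    # prevent OOM; this helper is only an illustration of a scripted mask builder.
    idx = torch.arange(seq_len, device=device)
    rel = idx.unsqueeze(0) - idx.unsqueeze(1)  # rel[q, k] = k - q
    return (rel > 0) | (rel < -window)


# Example: 8 tokens, each attending to at most the 4 previous tokens.
mask = sliding_window_disallowed_mask(8, 4, torch.device("cpu"))
print(mask)
```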
Hi! I'm trying to fine-tune Mistral 7B with a chat dataset, packing, train on completions only, and FA. I got spammed with a warning. Thanks, Toni
djsaunde pushed a commit that referenced this pull request on Dec 17, 2024:
Mistral: Sliding Window Attention with Flash Attention and Sample Packing (#732)
Main benefits of this PR:

- Passes `window_size` to Flash Attention, so the sliding window is applied inside the fused attention kernel (a minimal sketch of the call is shown below).
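For reference, here is a minimal sketch of what passing `window_size` to FlashAttention looks like, for both the plain batched call and the packed (varlen) call used with sample packing. This is an illustration against flash-attn >= 2.3.0, not the PR's exact code; the tensor shapes and `cu_seqlens` values are made up.

```python
import torch
from flash_attn import flash_attn_func, flash_attn_varlen_func

# Regular batched case: q/k/v are (batch, seqlen, n_heads, head_dim).
q = torch.randn(1, 8192, 32, 128, dtype=torch.float16, device="cuda")
k, v = torch.randn_like(q), torch.randn_like(q)

# window_size=(left, right): each token attends to at most `left` previous tokens;
# (-1, -1) disables the window, i.e. plain causal attention.
out_swa  = flash_attn_func(q, k, v, causal=True, window_size=(4096, 4096))
out_full = flash_attn_func(q, k, v, causal=True, window_size=(-1, -1))

# Sample-packing case: sequences are concatenated into one long row and described by
# cumulative sequence lengths, so attention never crosses sample boundaries.
total_tokens = 8192
qp = torch.randn(total_tokens, 32, 128, dtype=torch.float16, device="cuda")
kp, vp = torch.randn_like(qp), torch.randn_like(qp)
cu_seqlens = torch.tensor([0, 3000, 5500, 8192], dtype=torch.int32, device="cuda")
max_seqlen = 3000  # longest individual packed sequence

out_packed = flash_attn_varlen_func(
    qp, kp, vp,
    cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
    max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen,
    causal=True, window_size=(4096, 4096),
)
```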
Memory usage

Memory usage with SWA. The conclusion is that you save 3GB when using a sliding window mask.

- `_prepare_decoder_attention_mask` with the `window_size=(4096, 4096)` parameter passed to flash attention.
- `_prepare_decoder_attention_mask` with the `window_size=(-1, -1)` parameter passed to flash attention.

Long context (casperhansen/longalpaca_1k_test)
I tested with a long-context dataset (minimum 16k tokens, maximum 32k tokens). A minimum of 48GB VRAM is needed to run this.
Results after a few steps:
Short context (mhenrichsen/alpaca_2k_test)
Loss on short-context datasets is tested to be the same.
Used the default config in `examples/mistral/qlora.yml`.

Other notes:
- `attention_mask` and `sliding_window_mask` are not broadcastable in the first iteration after eval loss. However, this is only the case when `wandb` is enabled. The error is handled by an `attention_mask.shape[0] != 1` check so that it does not trigger (see the sketch after this list).
- I tried `_expand_mask` and it did not work with Flash Attention. I tried other methods too, but same problem.
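To illustrate the guard described in the first note, here is a hedged sketch of how an `attention_mask.shape[0] != 1` check could gate the merge of the two masks. The helper name and the additive-mask convention are assumptions, not the PR's exact code.

```python
from typing import Optional

import torch


def merge_masks(attention_mask: Optional[torch.Tensor],
                sliding_window_mask: torch.Tensor) -> torch.Tensor:
    # In the iteration right after eval loss (observed only with wandb enabled),
    # `attention_mask` arrives with a shape that does not broadcast against
    # `sliding_window_mask`. Only merging when attention_mask.shape[0] != 1
    # skips that degenerate case and avoids the broadcasting error.
    if attention_mask is not None and attention_mask.shape[0] != 1:
        return sliding_window_mask + attention_mask  # both treated as additive masks
    return sliding_window_mask
```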