Introduce MsT technologies into unsloth to extend sequence length #1082
base: nightly
Conversation
Thank you @wdlctc ! We will review it and hopefully be able to push it in after our multimodal release! :)
Thank you @shimmyshimmer for your review. I have added detailed training info for reference:
For more implementation details, you can refer to our blog: https://wdlctc.github.io/mst.html or our paper https://www.arxiv.org/abs/2407.15892. If you need other fine-tuning settings, I can try them another time.
Rewrote it with unsloth's fast_cross_entropy. We were surprised to find that integrating MsT with unsloth not only improves memory behavior but also introduces a speedup. The key difference: checkpointing the hidden_state of the LM head (its input) instead of checkpointing the logits (its output).
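To make the idea concrete, here is a minimal illustrative sketch (not this PR's actual code) of an LM-head loss that saves only the hidden states for backward and recomputes the logits chunk by chunk; it uses plain F.cross_entropy where the PR uses unsloth's fast_cross_entropy, and the names ChunkedLMHeadLoss and chunk_size are made up for the example:

```python
import torch
import torch.nn.functional as F


class ChunkedLMHeadLoss(torch.autograd.Function):
    """Cross-entropy over the LM head computed chunk by chunk, so the full
    [seq, vocab] logits tensor is never materialized. Only the hidden states
    (the LM head *input*) are saved for backward; logits are recomputed."""

    @staticmethod
    def forward(ctx, hidden, weight, labels, chunk_size=1024):
        # hidden: [seq, hidden_dim] (already flattened/shifted), labels: [seq]
        total_loss = hidden.new_zeros((), dtype=torch.float32)
        n_valid = (labels != -100).sum().clamp(min=1)
        for start in range(0, hidden.shape[0], chunk_size):
            h = hidden[start:start + chunk_size]
            y = labels[start:start + chunk_size]
            logits = h @ weight.t()                    # chunked logits, freed after this iteration
            total_loss += F.cross_entropy(logits.float(), y,
                                          ignore_index=-100, reduction="sum")
        ctx.save_for_backward(hidden, weight, labels)  # checkpoint the input, not the logits
        ctx.chunk_size = chunk_size
        ctx.n_valid = n_valid
        return total_loss / n_valid

    @staticmethod
    def backward(ctx, grad_output):
        hidden, weight, labels = ctx.saved_tensors
        grad_hidden = torch.zeros_like(hidden)
        grad_weight = torch.zeros_like(weight)
        for start in range(0, hidden.shape[0], ctx.chunk_size):
            h = hidden[start:start + ctx.chunk_size]
            y = labels[start:start + ctx.chunk_size]
            probs = (h @ weight.t()).float().softmax(dim=-1)  # recompute logits on the fly
            valid = y != -100
            probs[valid, y[valid]] -= 1.0              # d(cross_entropy) / d(logits)
            probs[~valid] = 0.0
            d_logits = (probs * grad_output / ctx.n_valid).to(h.dtype)
            grad_hidden[start:start + ctx.chunk_size] = d_logits @ weight
            grad_weight += d_logits.t() @ h
        return grad_hidden, grad_weight, None, None
```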
@wdlctc Thanks a lot again!! I'll test it and verify all losses match! Appreciate it!
10/14/2024: Resolved the conflicts with the nightly branch
Sorry on the delay - was planning to add this together with Vision support :) It might take a few more days!
Oh lol I noticed I accidentally deleted this PR after I deleted the nightly branch - whoops so sorry!
Yes! The key insight is that the full logits tensor is too big, especially when the vocabulary size is large, as on LLaMA 3 (128k) and Gemma 2 (256k), so recomputing the logits on the fly can effectively reduce memory (only compute one chunk at a time and discard the previous chunk) and time (for offloading). We suggest chunking along the rows, but you can also do both rows and columns, as for the LM head and MLP the rows and columns (batch and seq) are independent. And it is effective because long-context training would use local_batch_size=1.
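For a rough sense of scale (illustrative arithmetic only, not measurements from this PR; the vocabulary sizes are approximate), the full logits for a single long sequence already run to several GiB at bf16, before any fp32 upcast for the loss or gradient storage:

```python
# Illustrative only: size of the full [seq_len, vocab_size] logits tensor for
# one long sequence at bf16.
seq_len = 32_768            # example long-context sequence, local_batch_size = 1
bf16_bytes = 2

for name, vocab_size in [("Llama 3 (~128k vocab)", 128_256),
                         ("Gemma 2 (~256k vocab)", 256_128)]:
    logits_gib = seq_len * vocab_size * bf16_bytes / 1024**3
    print(f"{name}: ~{logits_gib:.1f} GiB of logits for a single {seq_len}-token sequence")
```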
Description
This pull request introduces optimizations to the LLaMA model implementation, specifically targeting the language modeling head and forward pass. The main changes include:
Implement a custom _LM_head using torch.autograd.Function for more efficient forward and backward passes.
Introduce an LMheadWarpper class to manage the custom LM head.
Add a minis_processing function to handle mini-batch processing of hidden states and labels (see the sketch after this list).
Modify the CausalLM_fast_forward function to use the new mini-batch processing and custom LM head.
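As a rough illustration of the mini-batch splitting described above (the actual minis_processing in this PR may differ in signature and behavior; split_into_minis and mini_seq are placeholder names):

```python
from typing import List, Tuple
import torch

# Hypothetical sketch only: split hidden states and labels into mini-sequences
# along the sequence dimension so the LM head and loss can be evaluated one
# mini-sequence at a time instead of over the full sequence at once.
def split_into_minis(hidden_states: torch.Tensor,   # [batch, seq, hidden]
                     labels: torch.Tensor,          # [batch, seq]
                     mini_seq: int = 2048) -> List[Tuple[torch.Tensor, torch.Tensor]]:
    minis = []
    for start in range(0, hidden_states.shape[1], mini_seq):
        minis.append((hidden_states[:, start:start + mini_seq],
                      labels[:, start:start + mini_seq]))
    return minis
```

A CausalLM_fast_forward-style wrapper would then loop over these mini-sequences, feeding each through the custom LM head and accumulating the loss, so the full logits are never materialized at once.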
Changes
Benefits
Testing
Please test this implementation thoroughly, especially:
Performance comparison with the original implementation
Correctness of loss calculation and gradient computation (a sketch of such a check follows this list)
Memory usage across various input sizes
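One way to cover the loss/gradient correctness point above is to compare a chunked loss against a plain full-logits reference on small random inputs; this is a self-contained sketch that checks the chunking math with ordinary autograd, and a real test would swap in the PR's custom LM head / mini-batch path in place of the inline loop:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq, hid, vocab, chunk = 512, 64, 1000, 128
hidden = torch.randn(seq, hid, requires_grad=True)
weight = torch.randn(vocab, hid, requires_grad=True)
labels = torch.randint(0, vocab, (seq,))

# Reference: materialize the full [seq, vocab] logits and take the mean loss.
ref_loss = F.cross_entropy(hidden @ weight.t(), labels)
ref_loss.backward()
ref_gh, ref_gw = hidden.grad.clone(), weight.grad.clone()
hidden.grad = None
weight.grad = None

# Chunked: per-chunk summed losses divided by the total token count, which is
# mathematically identical to the mean over the full logits.
chunk_loss = sum(
    F.cross_entropy(hidden[s:s + chunk] @ weight.t(), labels[s:s + chunk],
                    reduction="sum")
    for s in range(0, seq, chunk)
) / seq
chunk_loss.backward()

assert torch.allclose(ref_loss, chunk_loss, atol=1e-4)
assert torch.allclose(ref_gh, hidden.grad, atol=1e-5)
assert torch.allclose(ref_gw, weight.grad, atol=1e-5)
```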