convert-diff-transformer CLI command / codepath #2197
Draft
djsaunde wants to merge 19 commits into main from diff-transformer
Conversation
winglian reviewed Dec 17, 2024
* basic evaluate CLI command / codepath
* tests for evaluate CLI command
* fixes and cleanup
* review comments; slightly DRYing up things

Co-authored-by: Dan Saunders <[email protected]>
c5ff9ae to d1ba285
Description
This PR implements the differential attention layer from the Differential Transformer paper.
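For reference, a simplified single-head sketch of the differential attention computation from the paper (omitting per-head normalization and the λ reparameterization; the names and shapes are illustrative):

```python
import torch
import torch.nn.functional as F


def differential_attention(q1, q2, k1, k2, v, lam: float) -> torch.Tensor:
    """Difference of two softmax attention maps applied to the values (sketch).

    q1/q2 and k1/k2 are the positive/negative query and key projections with
    shape (batch, seq, head_dim); v has shape (batch, seq, head_dim); lam is
    the lambda scalar weighting the negative map.
    """
    d = q1.size(-1)
    a1 = F.softmax(q1 @ k1.transpose(-1, -2) / d**0.5, dim=-1)  # positive map
    a2 = F.softmax(q2 @ k2.transpose(-1, -2) / d**0.5, dim=-1)  # negative map
    return (a1 - lam * a2) @ v  # subtract the maps, then attend to the values
```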
Motivation and Context
We wanted to add this attention implementation to axolotl so users can swap out the existing attention layers in their models for this more performant version. We matched the official implementation details as closely as possible, while adapting it to play nicely with the transformers attention implementations.

Since we were focused on being able to convert existing LLMs to use these differential attention layers, we wanted a way to avoid degrading the performance of the (possibly pre-trained) LLM while doing so. To this end, the conversion process doubles the dimensionality of the query and key projections (since differential attention requires both a positive and a negative component of the attention) and optionally (pass `--zero-init`) initializes the weights of the negative component to zero, while copying the weights from the original attention modules over to the positive components. When doing this, the converted network computes the same function as the original (pass `--debug` to confirm this), but may suffer from a vanishing gradient problem. The default behavior is thus to initialize the weights of the negative components of the differential attention layers to 0-centered, normally distributed values with a small variance.
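A rough sketch of this initialization scheme, assuming a doubled `nn.Linear` projection; the function name and the 0.02 standard deviation are illustrative, not the PR's actual code:

```python
import torch
import torch.nn as nn


def init_doubled_projection(
    orig_proj: nn.Linear, diff_proj: nn.Linear, zero_init: bool = False
) -> None:
    """Initialize a doubled query/key projection from the original one (sketch).

    The positive half reuses the pre-trained weights so the converted model can
    reproduce the original; the negative half is either zeroed (--zero-init) or
    given small 0-centered noise (default) to avoid vanishing gradients.
    """
    half = orig_proj.out_features  # positive half keeps the original output size
    assert diff_proj.out_features == 2 * half

    with torch.no_grad():
        diff_proj.weight[:half].copy_(orig_proj.weight)  # positive component
        if zero_init:
            diff_proj.weight[half:].zero_()  # exact function preservation
        else:
            # "Small variance" per the description; 0.02 std is an assumption.
            diff_proj.weight[half:].normal_(mean=0.0, std=0.02)
```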
Relevant links:
How has this been tested?
SmolLM2-135M on an A40 Runpod instance on this feature branch. The workflow was:

- Convert the model with the `--zero-init` and `--debug` flags to sanity-check exact model conversion (completions, logits, losses).
- Run the `axolotl evaluate` command on the small `mhenrichsen/alpaca_2k_test` dataset with both the original and converted models, and check that their evaluation metrics match.

For example:
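A hypothetical sketch of this two-step workflow driven from Python; the `config.yml` path and argument order are assumptions, not taken from the PR:

```python
# Sketch only: run the conversion with --zero-init/--debug, then evaluate.
import subprocess

# Convert a pre-trained checkpoint, zero-initializing the negative attention
# components; --debug checks that the converted model reproduces the original
# model's completions, logits, and losses.
subprocess.run(
    ["axolotl", "convert-diff-transformer", "config.yml", "--zero-init", "--debug"],
    check=True,
)

# Evaluate on the dataset configured in config.yml (e.g. mhenrichsen/alpaca_2k_test)
# for both the original and the converted model, then compare metrics.
subprocess.run(["axolotl", "evaluate", "config.yml"], check=True)
```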
Types of changes
- `axolotl.integrations.diff_transformer` module, which implements the differential attention layers for the Llama LLM architecture and for various attention implementations (eager, SDPA, Flash Attention 2), and
- `axolotl.cli.integrations.convert_diff_transformer` module (and updates to `axolotl.cli.main`), which implements the `convert-diff-transformer` CLI command, and
- `axolotl.cli.integrations.convert_diff_transformer.patches` (to be moved) for updating the `LLAMA_ATTENTION_CLASSES` constant in `transformers.models.llama.modeling_llama` (a sketch of this patch follows the list).
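A minimal sketch of what that patch might look like; the differential attention class names and their import path are placeholders, not the PR's actual symbols:

```python
# Illustrative only: register differential attention classes under the keys
# transformers uses to select Llama's attention implementation.
from transformers.models.llama import modeling_llama

from axolotl.integrations.diff_transformer import (  # hypothetical export names
    LlamaDifferentialAttention,
    LlamaDifferentialSdpaAttention,
    LlamaDifferentialFlashAttention2,
)

modeling_llama.LLAMA_ATTENTION_CLASSES.update(
    {
        "eager": LlamaDifferentialAttention,
        "sdpa": LlamaDifferentialSdpaAttention,
        "flash_attention_2": LlamaDifferentialFlashAttention2,
    }
)
```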
TODO