convert-diff-transformer CLI command / codepath #2197

Draft
wants to merge 19 commits into main

Conversation


@djsaunde commented Dec 17, 2024

Description

This PR implements the differential attention layer from the Differential Transformer paper.
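
For reference, differential attention replaces the single softmax attention map with the difference of two maps computed from separate query/key projections. A minimal sketch of the core computation (omitting the paper's λ reparameterization, per-head normalization, and causal masking; function and variable names here are illustrative, not this PR's API):

```python
import torch.nn.functional as F

def differential_attention(q1, k1, q2, k2, v, lam):
    """Difference of two softmax attention maps applied to the values.

    q1/k1 are the "positive" projections, q2/k2 the "negative" ones, and
    lam is the paper's learnable scalar lambda.
    Shapes: (batch, heads, seq_len, head_dim).
    """
    scale = q1.size(-1) ** -0.5
    attn_pos = F.softmax(q1 @ k1.transpose(-2, -1) * scale, dim=-1)
    attn_neg = F.softmax(q2 @ k2.transpose(-2, -1) * scale, dim=-1)
    return (attn_pos - lam * attn_neg) @ v
```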

Motivation and Context

We wanted to add this attention implementation to axolotl so users can swap out the existing attention layers in their models for this more performant variant. We matched the official implementation as closely as possible, while adapting it to play nicely with the transformers attention implementations.

Since we were focused on converting existing LLMs to use these differential attention layers, we wanted a way to avoid degrading the performance of the (possibly pre-trained) LLM during conversion.

To this end, the conversion process doubles the dimensionality of the query and key projections (since differential attention requires both a positive and a negative attention component) and, optionally (pass --zero-init), initializes the weights of the negative component to zero while copying the weights from the original attention modules into the positive components.

With --zero-init, the converted network computes the same function as the original (pass --debug to confirm this), but it may then suffer from a vanishing gradient problem in the negative components. The default behavior is therefore to initialize the weights of the negative components to zero-centered, normally distributed values with a small variance.
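
Concretely, the weight-handling step of the conversion looks roughly like the sketch below. This is illustrative rather than the code in this PR: the q_proj/k_proj/v_proj/o_proj attributes follow the standard Llama attention layout, std=0.01 stands in for the "small variance" default, and the function name is made up.

```python
import torch
import torch.nn as nn

def copy_attention_weights(old_attn, new_attn, zero_init: bool = False) -> None:
    """Copy the original projections into the positive half of the doubled
    query/key projections and initialize the negative half."""
    half = old_attn.q_proj.out_features  # new q_proj/k_proj have 2x out_features

    with torch.no_grad():
        # Positive components: copy the original projection weights verbatim.
        new_attn.q_proj.weight[:half].copy_(old_attn.q_proj.weight)
        new_attn.k_proj.weight[:half].copy_(old_attn.k_proj.weight)

        if zero_init:
            # --zero-init: exact zeros for the negative components (per the
            # description, the converted model then reproduces the original's
            # outputs; confirm with --debug).
            new_attn.q_proj.weight[half:].zero_()
            new_attn.k_proj.weight[half:].zero_()
        else:
            # Default: zero-centered noise with small variance to avoid
            # vanishing gradients in the negative components.
            nn.init.normal_(new_attn.q_proj.weight[half:], mean=0.0, std=0.01)
            nn.init.normal_(new_attn.k_proj.weight[half:], mean=0.0, std=0.01)

        # Value and output projections keep their original shapes and weights.
        new_attn.v_proj.weight.copy_(old_attn.v_proj.weight)
        new_attn.o_proj.weight.copy_(old_attn.o_proj.weight)
```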

Relevant links:

How has this been tested?

SmolLM2-135m on an A40 RunPod instance on this feature branch. The workflow was:

  • Convert the model to use either eager or SDPA differential attention
    • With and without the --zero-init and --debug flags, to sanity-check exact model conversion (completions, logits, losses)
  • Run the new axolotl evaluate command on the small mhenrichsen/alpaca_2k_test dataset with both the original and converted models and check that their evaluation metrics match

For example:

$ axolotl convert-diff-transformer ../configs/smollm.yaml --output-dir ../converted-model --zero-init --debug
...
[2024-12-17 05:15:26,910] [INFO] [axolotl.cli.convert_attention.convert_diff_transformer:75] [PID:94590] [RANK:0] Converting to differential attention...
[2024-12-17 05:15:26,910] [INFO] [axolotl.integrations.diff_transformer.convert.convert_module:97] [PID:94590] [RANK:0] Converting attention layer 0: LlamaSdpaAttention to LlamaDifferentialSdpaAttention
[2024-12-17 05:15:26,921] [DEBUG] [axolotl.integrations.diff_transformer.convert.copy_attention_weights:64] [PID:94590] [RANK:0] Copied positive attention weights from LlamaSdpaAttention to LlamaDifferentialSdpaAttention
[2024-12-17 05:15:26,921] [INFO] [axolotl.integrations.diff_transformer.convert.convert_module:97] [PID:94590] [RANK:0] Converting attention layer 1: LlamaSdpaAttention to LlamaDifferentialSdpaAttention
[2024-12-17 05:15:26,930] [DEBUG] [axolotl.integrations.diff_transformer.convert.copy_attention_weights:64] [PID:94590] [RANK:0] Copied positive attention weights from LlamaSdpaAttention to LlamaDifferentialSdpaAttention
...
ANK:0] Converted 30 attention layers to differential attention
[2024-12-17 05:15:27,181] [INFO] [axolotl.cli.convert_attention.convert_diff_transformer:85] [PID:94590] [RANK:0] Testing converted model...
[2024-12-17 05:15:27,785] [INFO] [axolotl.cli.convert_attention.test_inference:43] [PID:94590] [RANK:0] Prompt: The quick brown fox                                                                                                                       
[2024-12-17 05:15:28,280] [INFO] [axolotl.cli.convert_attention.convert_diff_transformer:121] [PID:94590] [RANK:0] Generations match!
Model generation:
**************************************************
The quick brown fox jumps over the lazy dog

The quick brown fox jumps over the lazy dog.

The
**************************************************

Types of changes

  • axolotl.integrations.diff_transformer module, which implements the differential attention layers for the Llama LLM architecture and for various attention implementations (eager, SDPA, Flash Attention 2), and
  • axolotl.cli.integrations.convert_diff_transformer module (and updates to axolotl.cli.main), which implements the convert-diff-transformer CLI command, and
  • Monkeypatch in axolotl.cli.integrations.convert_diff_transformer.patches (to be moved) for updating the LLAMA_ATTENTION_CLASSES constant in transformers.models.llama.modeling_llama (see the sketch after this list).
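
The actual patch lives in the module named above; the sketch below only illustrates the general shape of such a patch. The import path and the eager-attention class name LlamaDifferentialAttention are assumptions; LlamaDifferentialSdpaAttention is taken from the conversion log above.

```python
from transformers.models.llama import modeling_llama

# Hypothetical import path; the real classes live in the new
# axolotl.integrations.diff_transformer module.
from axolotl.integrations.diff_transformer import (
    LlamaDifferentialAttention,
    LlamaDifferentialSdpaAttention,
)

def patch_llama_attention_classes() -> None:
    """Point transformers' Llama attention registry at the differential variants."""
    modeling_llama.LLAMA_ATTENTION_CLASSES.update(
        {
            "eager": LlamaDifferentialAttention,
            "sdpa": LlamaDifferentialSdpaAttention,
        }
    )
```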

TODO

  • Test coverage
  • Add Flash Attention 2 implementation
  • Move monkey patch
  • Refactor conversion module as plugin
  • Add conversion with same-sized Q, K projections
  • Experiments to demonstrate value
    • Blog post

@djsaunde self-assigned this Dec 17, 2024