Differential Transformer with PyTorch Scaled Dot Product Attention

Introduction

A set of implementations of the Differential Transformer paper [1] using PyTorch's Scaled Dot Product Attention (SDPA) instead of the implementations provided by the authors:

  • Basic manual PyTorch
  • Flash Attention 2
    • Custom kernel that handles a differing head_dim more efficiently
    • Original kernel, which is more optimized for the same head_dim

This implementation currently comes in four variants:

  • Following the original Flash Attention 2 implementation more closely
  • Following the custom Flash Attention 2 implementation more closely
  • One forward pass for the attention calculation (transferable to the original Flash Attention 2 implementation)
  • One forward pass for the attention calculation based on [2] (utilizing SDPA's support for differing head_dim)
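For intuition: both softmax maps in differential attention multiply the same value tensor, so the subtraction can be pulled outside, (A1 - λ·A2)·V = A1·V - λ·A2·V, which is what makes plain SDPA calls usable here. Below is a minimal sketch of that two-pass form; the tensor shapes and the fixed λ are illustrative placeholders (the paper uses a learnable λ), not this package's API.

import torch
import torch.nn.functional as F

# illustrative shapes only: (bsz, num_heads, seq_len, head_dim)
bsz, num_heads, seq_len, head_dim = 2, 6, 3, 64
q1, q2, k1, k2 = (torch.randn(bsz, num_heads, seq_len, head_dim) for _ in range(4))
v = torch.randn(bsz, num_heads, seq_len, 2 * head_dim)  # SDPA allows v to have a different head_dim
lam = 0.8  # placeholder; the paper uses a learnable lambda

# (softmax(Q1 K1^T / sqrt(d)) - lam * softmax(Q2 K2^T / sqrt(d))) V
#   == sdpa(Q1, K1, V) - lam * sdpa(Q2, K2, V)
out = F.scaled_dot_product_attention(q1, k1, v) - lam * F.scaled_dot_product_attention(q2, k2, v)
assert out.shape == (bsz, num_heads, seq_len, 2 * head_dim)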

Note:

  • RoPE is optional as I primarily cared about equivalence to the reference implementations
  • RoPE and attention masks need proper handling by external code
  • Benchmarks are still needed to see which variant works better, especially for the two one-pass versions (a sketch of the concatenation approach follows below):
    • Same head_dim, more num_heads, but with concatenation and chunking/unbinding overhead
    • Different head_dim, fewer num_heads, but possibly lower utilization on the original Flash Attention 2 kernel
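As a rough illustration of the first one-pass strategy, the two query/key groups can be concatenated along the head axis, the values duplicated, and the result chunked again after a single SDPA call. This is only a sketch under the illustrative shapes from above, not the package's actual implementation.

import torch
import torch.nn.functional as F

bsz, num_heads, seq_len, head_dim = 2, 6, 3, 64
q1, q2, k1, k2 = (torch.randn(bsz, num_heads, seq_len, head_dim) for _ in range(4))
v = torch.randn(bsz, num_heads, seq_len, head_dim)
lam = 0.8  # placeholder lambda

# stack both attention groups along the head axis so one SDPA call covers both maps
q = torch.cat([q1, q2], dim=1)
k = torch.cat([k1, k2], dim=1)
v_dup = torch.cat([v, v], dim=1)

out1, out2 = F.scaled_dot_product_attention(q, k, v_dup).chunk(2, dim=1)
diff_out = out1 - lam * out2
assert diff_out.shape == (bsz, num_heads, seq_len, head_dim)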

Installation

I won't distribute a PyPI package, but you can use it as a package by cloning the repo and installing it from the root:

git clone https://github.com/vasqu/multihead-sdpadiff.git
cd multihead-sdpadiff
pip install .

Usage

import torch

from multihead_sdpadiff import (
  MultiheadSdpaDiff1,  # multiple attn passes
  MultiheadSdpaDiff2,  # two attn passes
  MultiheadSdpaDiff3,  # one attn pass (v1)
  MultiheadSdpaDiff4,  # one attn pass (v2)
)

# some shape values
bsz = 2
seq_len = 3
depth = 12
embed_dim = 768
num_heads = 12  # halved internally since the heads are doubled for the differential attention

# random input
x = torch.randn(size=(bsz, seq_len, embed_dim))

# choose an implementation
#sdpa_mha_diff = MultiheadSdpaDiff1(embed_dim, depth, num_heads, num_heads)
#sdpa_mha_diff = MultiheadSdpaDiff2(embed_dim, depth, num_heads, num_heads)
#sdpa_mha_diff = MultiheadSdpaDiff3(embed_dim, depth, num_heads, num_heads)
sdpa_mha_diff = MultiheadSdpaDiff4(embed_dim, depth, num_heads, num_heads)

# pass and check
res = sdpa_mha_diff(x)
assert res.shape == x.shape

TODOs

  • Make it a package structure
  • Benchmark speed/memory across the implementations
  • Transformer style RoPE + Attn Mask

Citation

[1]
@misc{ye2024differentialtransformer,
      title={Differential Transformer}, 
      author={Tianzhu Ye and Li Dong and Yuqing Xia and Yutao Sun and Yi Zhu and Gao Huang and Furu Wei},
      year={2024},
      eprint={2410.05258},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2410.05258}, 
}

[2] Thanks to MarktHart for providing another version, which might be the most optimized one