vasqu/multihead-sdpadiff

Differential Transformer with PyTorch Scaled Dot Product Attention

Introduction

A set of implementations of the Differential Transformer paper [1] that use PyTorch's Scaled Dot Product Attention (SDPA) instead of the implementations provided with the paper:

  • Basic manual PyTorch
  • Flash Attention 2
    • Custom kernel to handle differing head_dim more efficiently
    • Original kernel, which is better optimized for identical head_dim
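All of these compute the same differential attention, which can be written out in plain PyTorch as a reference. This is a minimal sketch, assuming the paper's formulation (softmax(Q1 K1^T / sqrt(d)) - λ · softmax(Q2 K2^T / sqrt(d))) · V, where the values have width 2·head_dim and λ is a plain float. `manual_diff_attn` is a hypothetical helper, not part of this repo; the projections, the λ re-parametrization and the per-head normalization are left out.

```python
import math

import torch


def manual_diff_attn(q1, k1, q2, k2, v, lam):
    """Differential attention core with explicit attention maps (hypothetical helper).

    q1, k1, q2, k2: (bsz, num_heads, seq_len, head_dim)
    v:              (bsz, num_heads, seq_len, 2 * head_dim)
    lam:            differential weight lambda, assumed to be a plain float here
    """
    scale = 1.0 / math.sqrt(q1.size(-1))
    # Two independent softmax maps, each (bsz, num_heads, seq_len, seq_len)
    a1 = torch.softmax(q1 @ k1.transpose(-2, -1) * scale, dim=-1)
    a2 = torch.softmax(q2 @ k2.transpose(-2, -1) * scale, dim=-1)
    # Subtract the lambda-weighted second map, then apply it to the shared values
    return (a1 - lam * a2) @ v


torch.manual_seed(0)
bsz, num_heads, seq_len, head_dim = 2, 3, 5, 4
q1, k1, q2, k2 = (torch.randn(bsz, num_heads, seq_len, head_dim) for _ in range(4))
v = torch.randn(bsz, num_heads, seq_len, 2 * head_dim)

out = manual_diff_attn(q1, k1, q2, k2, v, lam=0.5)
print(out.shape)  # torch.Size([2, 3, 5, 8])
```

Feeding the same query/key pair twice with λ = 1 cancels both maps exactly, which makes for a quick sanity check.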

This implementation has four variations as of now:

  • Following the original Flash Attention 2 implementation more closely
  • Following the custom Flash Attention 2 implementation more closely
  • A single forward pass through the attention computation (transferable to the original Flash Attention 2 implementation)
  • A single forward pass through the attention computation based on [2] (utilizing SDPA's support for a value head_dim that differs from the query/key head_dim)
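The first two variants differ mainly in how the 2·head_dim values reach the kernel. A kernel that expects the same head_dim for queries, keys and values needs the values split in half and attended four times, while SDPA also accepts a value head_dim that differs from the query/key head_dim, so two passes suffice. This is a minimal sketch under the same assumptions as the paper's formulation (q/k with head_dim, values with 2·head_dim, a plain float λ). The function names are hypothetical and only illustrate the idea behind MultiheadSdpaDiff1 and MultiheadSdpaDiff2, not their actual code.

```python
import math

import torch
import torch.nn.functional as F


def manual_diff_attn(q1, k1, q2, k2, v, lam):
    # Reference: explicit softmax maps, values of width 2 * head_dim
    scale = 1.0 / math.sqrt(q1.size(-1))
    a1 = torch.softmax(q1 @ k1.transpose(-2, -1) * scale, dim=-1)
    a2 = torch.softmax(q2 @ k2.transpose(-2, -1) * scale, dim=-1)
    return (a1 - lam * a2) @ v


def diff_attn_four_passes(q1, k1, q2, k2, v, lam):
    # Same head_dim everywhere: split the values in half and let each
    # (q, k) pair attend to both halves, then glue the halves back together
    v1, v2 = v.chunk(2, dim=-1)
    attn1 = torch.cat([F.scaled_dot_product_attention(q1, k1, v1),
                       F.scaled_dot_product_attention(q1, k1, v2)], dim=-1)
    attn2 = torch.cat([F.scaled_dot_product_attention(q2, k2, v1),
                       F.scaled_dot_product_attention(q2, k2, v2)], dim=-1)
    return attn1 - lam * attn2


def diff_attn_two_passes(q1, k1, q2, k2, v, lam):
    # SDPA accepts a value head_dim (2 * head_dim) that differs from the
    # query/key head_dim, so the full values can be used in each pass;
    # the default scale is 1 / sqrt(q.size(-1)), matching the reference
    attn1 = F.scaled_dot_product_attention(q1, k1, v)
    attn2 = F.scaled_dot_product_attention(q2, k2, v)
    return attn1 - lam * attn2


torch.manual_seed(0)
bsz, num_heads, seq_len, head_dim = 2, 3, 5, 4
q1, k1, q2, k2 = (torch.randn(bsz, num_heads, seq_len, head_dim) for _ in range(4))
v = torch.randn(bsz, num_heads, seq_len, 2 * head_dim)

ref = manual_diff_attn(q1, k1, q2, k2, v, lam=0.5)
out_four = diff_attn_four_passes(q1, k1, q2, k2, v, lam=0.5)
out_two = diff_attn_two_passes(q1, k1, q2, k2, v, lam=0.5)
print(torch.allclose(out_four, ref, atol=1e-5), torch.allclose(out_two, ref, atol=1e-5))
```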

Notes:

  • RoPE is optional, as I cared about equivalence first and foremost
  • RoPE and attention masks must be handled properly outside of these modules
  • Benchmarks are still needed to see which variant works better, especially between the two one-pass versions:
    • Same head_dim: more num_heads, but requires concatenating and chunking/unbinding
    • Different head_dim: fewer num_heads, but possibly lower utilization on the original Flash Attention 2
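The two one-pass layouts can be sketched by stacking the sub-attentions along the head dimension, running a single SDPA call and chunking the result. This assumes q/k with head_dim, values with 2·head_dim, a plain float λ, and a causal mask passed straight to SDPA via `is_causal=True` as a stand-in for the externally handled masks. The function names are hypothetical; this only illustrates the two head layouts, not this repo's code.

```python
import math

import torch
import torch.nn.functional as F


def manual_causal_diff_attn(q1, k1, q2, k2, v, lam):
    # Reference: explicit causal masking of both softmax maps
    seq_len, head_dim = q1.shape[-2:]
    mask = torch.ones(seq_len, seq_len, dtype=torch.bool).tril()
    s1 = (q1 @ k1.transpose(-2, -1) / math.sqrt(head_dim)).masked_fill(~mask, float("-inf"))
    s2 = (q2 @ k2.transpose(-2, -1) / math.sqrt(head_dim)).masked_fill(~mask, float("-inf"))
    return (torch.softmax(s1, dim=-1) - lam * torch.softmax(s2, dim=-1)) @ v


def one_pass_same_head_dim(q1, k1, q2, k2, v, lam, is_causal=True):
    # 4 * num_heads heads of width head_dim:
    # (q1, k1, v1), (q1, k1, v2), (q2, k2, v1), (q2, k2, v2)
    v1, v2 = v.chunk(2, dim=-1)
    q = torch.cat([q1, q1, q2, q2], dim=1)
    k = torch.cat([k1, k1, k2, k2], dim=1)
    vv = torch.cat([v1, v2, v1, v2], dim=1)
    out = F.scaled_dot_product_attention(q, k, vv, is_causal=is_causal)
    o11, o12, o21, o22 = out.chunk(4, dim=1)
    return torch.cat([o11, o12], dim=-1) - lam * torch.cat([o21, o22], dim=-1)


def one_pass_diff_head_dim(q1, k1, q2, k2, v, lam, is_causal=True):
    # 2 * num_heads heads: queries/keys keep head_dim, the duplicated
    # values keep their 2 * head_dim width
    q = torch.cat([q1, q2], dim=1)
    k = torch.cat([k1, k2], dim=1)
    vv = torch.cat([v, v], dim=1)
    out = F.scaled_dot_product_attention(q, k, vv, is_causal=is_causal)
    attn1, attn2 = out.chunk(2, dim=1)
    return attn1 - lam * attn2


torch.manual_seed(0)
bsz, num_heads, seq_len, head_dim = 2, 3, 5, 4
q1, k1, q2, k2 = (torch.randn(bsz, num_heads, seq_len, head_dim) for _ in range(4))
v = torch.randn(bsz, num_heads, seq_len, 2 * head_dim)

ref = manual_causal_diff_attn(q1, k1, q2, k2, v, lam=0.5)
out_v1 = one_pass_same_head_dim(q1, k1, q2, k2, v, lam=0.5)
out_v2 = one_pass_diff_head_dim(q1, k1, q2, k2, v, lam=0.5)
print(torch.allclose(out_v1, ref, atol=1e-5), torch.allclose(out_v2, ref, atol=1e-5))
```

Both layouts give the same result; the first runs 4·num_heads heads of width head_dim, the second 2·num_heads heads with value width 2·head_dim, which is the trade-off a benchmark would have to settle.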

Installation

I won't distribute a PyPI package, but you can use it as a package by cloning the repo and installing it from the root:

git clone https://github.com/vasqu/multihead-sdpadiff.git
cd multihead-sdpadiff
pip install .

Usage

import torch

from multihead_sdpadiff import (
  MultiheadSdpaDiff1,  # multiple attn passes
  MultiheadSdpaDiff2,  # two attn passes
  MultiheadSdpaDiff3,  # one attn pass (v1)
  MultiheadSdpaDiff4,  # one attn pass (v2)
)

# some shape values
bsz = 2
seq_len = 3
depth = 12
embed_dim = 768
num_heads = 12  # halved internally, since each head is doubled for the differential attention

# random input
x = torch.randn(size=(bsz, seq_len, embed_dim))

# choose an implementation
#sdpa_mha_diff = MultiheadSdpaDiff1(embed_dim, depth, num_heads, num_heads)
#sdpa_mha_diff = MultiheadSdpaDiff2(embed_dim, depth, num_heads, num_heads)
#sdpa_mha_diff = MultiheadSdpaDiff3(embed_dim, depth, num_heads, num_heads)
sdpa_mha_diff = MultiheadSdpaDiff4(embed_dim, depth, num_heads, num_heads)

# pass and check
res = sdpa_mha_diff(x)
assert res.shape == x.shape

TODOs

  • Make it a package structure
  • Benchmark the speed/memory between the implementations
  • Transformer style RoPE + Attn Mask

Citation

[1]
@misc{ye2024differentialtransformer,
      title={Differential Transformer}, 
      author={Tianzhu Ye and Li Dong and Yuqing Xia and Yutao Sun and Yi Zhu and Gao Huang and Furu Wei},
      year={2024},
      eprint={2410.05258},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2410.05258}, 
}

[2] Thanks to MarktHart for providing another version, which might be the most optimized one
