working qsdpa
Alex Barron committed Dec 6, 2024
1 parent e047fd9 commit 12a4d89
Showing 8 changed files with 853 additions and 46 deletions.
96 changes: 66 additions & 30 deletions benchmarks/python/sdpa_vector_bench.py
@@ -1,58 +1,94 @@
 import argparse
 import math
 
 import mlx.core as mx
 import numpy as np
+from mlx.utils import tree_map
 from time_utils import time_fn
 
-L = 16384
+L = 32768
 H = 32
 H_k = H // 4
 D = 128
 dtype = mx.float16
-loops = 10
+bits = 8
+
+loops = 20
 
 
 def attention(q, k, v):
-    def _sdpa(q, k, v):
+    for _ in range(loops):
         B, Hq, L, D = q.shape
         _, Hk, S, _ = k.shape
         q = q.reshape(B, Hk, Hq // Hk, L, D)
-        k = k[:, :, None, :, :]
-        v = v[:, :, None, :, :]
-        s = q @ k.transpose(0, 1, 2, 4, 3)
+        ke = k[:, :, None, :, :]
+        ve = v[:, :, None, :, :]
+        s = q @ ke.transpose(0, 1, 2, 4, 3)
         p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
-        o = p @ v
-        return o.reshape(B, Hq, L, D)
-
-    for i in range(loops):
-        q = _sdpa(q, k, v)
+        q = p @ ve
+        q = q.reshape(B, Hq, L, D)
     return q
 
 
 def sdpa(q, k, v):
-    for i in range(loops):
-        q = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0)
+    for _ in range(loops):
+        q = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=None)
     return q
 
 
-def time_self_attention_primitives():
-    mx.random.seed(3)
-    q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype)
-    k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
-    v = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
-    mx.eval(q, k, v)
+def quant_sdpa(q, k, v, bits=4):
+    for _ in range(loops):
+        q = mx.fast.quantized_scaled_dot_product_attention(
+            q, *k, *v, scale=1.0, mask=None, bits=bits
+        )
+    return q
+
+
+def quant_attention(q, k, v, bits=4):
+    for _ in range(loops):
+        B, Hq, L, D = q.shape
+        Hk = k[0].shape[1]
+
+        q = q.reshape((B, Hk, Hq // Hk, L, D))
+        ke = tree_map(lambda x: mx.expand_dims(x, axis=2), k)
+        ve = tree_map(lambda x: mx.expand_dims(x, axis=2), v)
+
+        scores = mx.quantized_matmul(q, *ke, transpose=True, bits=bits)
+        scores = mx.softmax(scores, axis=-1)
+
+        q = mx.quantized_matmul(scores, *ve, transpose=False, bits=bits)
+        q = q.reshape((B, Hq, L, D))
+    return q
+
+
+def time_self_attention_primitives(q, k, v):
     time_fn(attention, q, k, v)
 
 
-def time_self_attention_sdpa():
-    mx.random.seed(3)
-    q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype)
-    k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
-    v = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
-    mx.eval(q, k, v)
+def time_self_attention_sdpa(q, k, v):
     time_fn(sdpa, q, k, v)
 
 
+def time_self_attention_quant_sdpa(q, k, v, bits=4):
+    time_fn(quant_sdpa, q, k, v, bits)
+
+
+def time_self_attention_quant_primitives(q, k, v, bits=4):
+    time_fn(quant_attention, q, k, v, bits)
+
+
 if __name__ == "__main__":
-    time_self_attention_sdpa()
-    time_self_attention_primitives()
+    mx.random.seed(3)
+    q = mx.random.uniform(shape=(1, H, 1, D), dtype=dtype)
+    k = mx.random.uniform(shape=(1, H_k, L, D), dtype=dtype)
+    v = mx.random.uniform(shape=(1, H_k, L, D), dtype=dtype)
+    mx.eval(q, k, v)
+
+    k_quant = mx.quantize(k, bits=bits)
+    v_quant = mx.quantize(v, bits=bits)
+    mx.eval(k_quant, v_quant)
+
+    k = mx.dequantize(*k_quant, bits=bits)
+    v = mx.dequantize(*v_quant, bits=bits)
+
+    time_self_attention_sdpa(q, k, v)
+    time_self_attention_quant_sdpa(q, k_quant, v_quant, bits)
+    time_self_attention_primitives(q, k, v)
+    time_self_attention_quant_primitives(q, k_quant, v_quant, bits)
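Note on the quantized benchmark path: `mx.quantize` packs a tensor into a `(weights, scales, biases)` tuple, which is why `k_quant` and `v_quant` are unpacked with `*` when passed to `mx.fast.quantized_scaled_dot_product_attention` and `mx.quantized_matmul` above. A minimal sketch of that round trip, with toy shapes chosen purely for illustration (not the benchmark's sizes):

    import mlx.core as mx

    # Toy K tensor laid out as (batch, kv_heads, seq_len, head_dim).
    k = mx.random.uniform(shape=(1, 8, 1024, 128), dtype=mx.float16)

    # mx.quantize returns (quantized weights, scales, biases) for the requested
    # bit width; the benchmark above uses bits=8 with the default group size.
    k_quant = mx.quantize(k, bits=8)
    w_q, scales, biases = k_quant

    # The same tuple unpacks straight into the quantized ops, mirroring the
    # *k / *v unpacking in quant_sdpa and quant_attention above.
    k_fp16 = mx.dequantize(*k_quant, bits=8)

    # The round-trip error should be small at 8 bits.
    print(mx.max(mx.abs(k - k_fp16)))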
29 changes: 29 additions & 0 deletions mlx/backend/metal/kernels/scaled_dot_product_attention.metal
@@ -20,4 +20,33 @@ using namespace metal;
 instantiate_sdpa_vector_heads(float)
 instantiate_sdpa_vector_heads(bfloat16_t)
 instantiate_sdpa_vector_heads(float16_t)
+
+// Quantized SDPA vector instantiations
+#define instantiate_quant_sdpa_vector(name, type, head_dim, group_size, bits) \
+  instantiate_kernel( \
+      #name "_" #type "_" #head_dim "_" #group_size "_" #bits, \
+      name, type, head_dim, group_size, bits)
+
+#define instantiate_quant_sdpa_vector_passes(type, heads, group_size, bits) \
+  instantiate_quant_sdpa_vector(quant_sdpa_vector, type, heads, group_size, bits) \
+  instantiate_quant_sdpa_vector(quant_sdpa_vector_2pass_1, type, heads, group_size, bits)
+
+#define instantiate_quant_sdpa_vector_bits(type, heads, group_size) \
+  instantiate_quant_sdpa_vector_passes(type, heads, group_size, 4) \
+  instantiate_quant_sdpa_vector_passes(type, heads, group_size, 8)
+
+#define instantiate_quant_sdpa_vector_group_size(type, heads) \
+  instantiate_quant_sdpa_vector_bits(type, heads, 32) \
+  instantiate_quant_sdpa_vector_bits(type, heads, 64) \
+  instantiate_quant_sdpa_vector_bits(type, heads, 128)
+
+#define instantiate_quant_sdpa_vector_heads(type) \
+  instantiate_quant_sdpa_vector_group_size(type, 64) \
+  instantiate_quant_sdpa_vector_group_size(type, 96) \
+  instantiate_quant_sdpa_vector_group_size(type, 128)
+
+instantiate_quant_sdpa_vector_heads(float)
+instantiate_quant_sdpa_vector_heads(bfloat16_t)
+instantiate_quant_sdpa_vector_heads(float16_t)
+
 // clang-format on
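For context on the instantiation matrix above: each `instantiate_kernel` call stamps out one kernel named `<name>_<type>_<head_dim>_<group_size>_<bits>` (for example `quant_sdpa_vector_float16_t_128_64_8`), so quantized vector kernels exist for head dims 64/96/128, group sizes 32/64/128, and 4- or 8-bit KV caches, for both `quant_sdpa_vector` and `quant_sdpa_vector_2pass_1`. A hedged Python-side sketch of a call that should map onto one of these instantiations — the dispatch itself is internal to `mx.fast`, and the shapes here are illustrative only:

    import mlx.core as mx

    D = 128          # head_dim: instantiated for 64, 96, 128
    group_size = 64  # instantiated group sizes: 32, 64, 128
    bits = 8         # instantiated bit widths: 4, 8

    q = mx.random.uniform(shape=(1, 32, 1, D), dtype=mx.float16)
    k = mx.random.uniform(shape=(1, 8, 4096, D), dtype=mx.float16)
    v = mx.random.uniform(shape=(1, 8, 4096, D), dtype=mx.float16)

    # Quantize the KV cache with a matching group size / bit width.
    k_quant = mx.quantize(k, group_size=group_size, bits=bits)
    v_quant = mx.quantize(v, group_size=group_size, bits=bits)

    # Same call pattern as quant_sdpa in the benchmark; the signature is taken
    # from this commit's sdpa_vector_bench.py.
    o = mx.fast.quantized_scaled_dot_product_attention(
        q, *k_quant, *v_quant, scale=1.0, mask=None, bits=bits
    )
    mx.eval(o)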