Quantized SDPA #1515

Draft · wants to merge 5 commits into main
96 changes: 66 additions & 30 deletions benchmarks/python/sdpa_vector_bench.py
import argparse
import math

import mlx.core as mx
import numpy as np
from mlx.utils import tree_map
from time_utils import time_fn

L = 32768
H = 32
H_k = H // 4
D = 128
dtype = mx.float16
bits = 8

loops = 20


def attention(q, k, v):
    for _ in range(loops):
        B, Hq, L, D = q.shape
        _, Hk, S, _ = k.shape
        q = q.reshape(B, Hk, Hq // Hk, L, D)
        ke = k[:, :, None, :, :]
        ve = v[:, :, None, :, :]
        s = q @ ke.transpose(0, 1, 2, 4, 3)
        p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
        q = p @ ve
        q = q.reshape(B, Hq, L, D)
    return q


def sdpa(q, k, v):
    for _ in range(loops):
        q = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=None)
    return q


def quant_sdpa(q, k, v, bits=4):
    for _ in range(loops):
        q = mx.fast.quantized_scaled_dot_product_attention(
            q, *k, *v, scale=1.0, mask=None, bits=bits
        )
    return q


def quant_attention(q, k, v, bits=4):
    for _ in range(loops):
        B, Hq, L, D = q.shape
        Hk = k[0].shape[1]

        q = q.reshape((B, Hk, Hq // Hk, L, D))
        ke = tree_map(lambda x: mx.expand_dims(x, axis=2), k)
        ve = tree_map(lambda x: mx.expand_dims(x, axis=2), v)

        scores = mx.quantized_matmul(q, *ke, transpose=True, bits=bits)
        scores = mx.softmax(scores, axis=-1)

        q = mx.quantized_matmul(scores, *ve, transpose=False, bits=bits)
        q = q.reshape((B, Hq, L, D))
    return q


def time_self_attention_primitives(q, k, v):
    time_fn(attention, q, k, v)


def time_self_attention_sdpa(q, k, v):
    time_fn(sdpa, q, k, v)


def time_self_attention_quant_sdpa(q, k, v, bits=4):
    time_fn(quant_sdpa, q, k, v, bits)


def time_self_attention_quant_primitives(q, k, v, bits=4):
    time_fn(quant_attention, q, k, v, bits)


if __name__ == "__main__":
    mx.random.seed(3)
    q = mx.random.uniform(shape=(1, H, 1, D), dtype=dtype)
    k = mx.random.uniform(shape=(1, H_k, L, D), dtype=dtype)
    v = mx.random.uniform(shape=(1, H_k, L, D), dtype=dtype)
    mx.eval(q, k, v)

    # Quantize K/V once; each result is a (packed_weights, scales, biases) tuple
    k_quant = mx.quantize(k, bits=bits)
    v_quant = mx.quantize(v, bits=bits)
    mx.eval(k_quant, v_quant)

    # Dequantize so the float baselines see the same values as the quantized path
    k = mx.dequantize(*k_quant, bits=bits)
    v = mx.dequantize(*v_quant, bits=bits)

    time_self_attention_sdpa(q, k, v)
    time_self_attention_quant_sdpa(q, k_quant, v_quant, bits)
    time_self_attention_primitives(q, k, v)
    time_self_attention_quant_primitives(q, k_quant, v_quant, bits)
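
Side note for readers: `mx.quantize` returns a `(packed_weights, scales, biases)` tuple, which is why the benchmark splats `k_quant`/`v_quant` with `*` into the quantized ops. A minimal sketch of that round trip (illustrative, not part of this diff; mirrors the benchmark's K/V shapes and MLX's default group size of 64):

import mlx.core as mx

x = mx.random.uniform(shape=(1, 8, 1024, 128)).astype(mx.float16)
x_q = mx.quantize(x, group_size=64, bits=8)        # (packed_weights, scales, biases)
x_hat = mx.dequantize(*x_q, group_size=64, bits=8)
print(mx.max(mx.abs(x - x_hat)))                   # small quantization error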
24 changes: 24 additions & 0 deletions mlx/backend/metal/kernels/scaled_dot_product_attention.metal
using namespace metal;

instantiate_sdpa_vector_heads(float)
instantiate_sdpa_vector_heads(bfloat16_t)
instantiate_sdpa_vector_heads(float16_t)

// Quantized SDPA vector instantiations
#define instantiate_quant_sdpa_vector(type, head_dim, group_size, bits) \
instantiate_kernel( \
"quant_sdpa_vector_2pass_1_" #type "_" #head_dim "_" #group_size "_" #bits, \
quant_sdpa_vector_2pass_1, type, head_dim, group_size, bits)

#define instantiate_quant_sdpa_vector_bits(type, heads, group_size) \
instantiate_quant_sdpa_vector(type, heads, group_size, 4) \
instantiate_quant_sdpa_vector(type, heads, group_size, 8)

#define instantiate_quant_sdpa_vector_group_size(type, heads) \
instantiate_quant_sdpa_vector_bits(type, heads, 32) \
instantiate_quant_sdpa_vector_bits(type, heads, 64) \
instantiate_quant_sdpa_vector_bits(type, heads, 128)

#define instantiate_quant_sdpa_vector_heads(type) \
instantiate_quant_sdpa_vector_group_size(type, 64) \
instantiate_quant_sdpa_vector_group_size(type, 128)

instantiate_quant_sdpa_vector_heads(float)
instantiate_quant_sdpa_vector_heads(bfloat16_t)
instantiate_quant_sdpa_vector_heads(float16_t)

// clang-format on
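
Taken together, the macros instantiate one kernel per (dtype, head_dim, group_size, bits) combination: 3 × 2 × 3 × 2 = 36 variants. A small Python sketch (illustrative only) of the names they generate:

for t in ["float", "bfloat16_t", "float16_t"]:
    for head_dim in [64, 128]:
        for group_size in [32, 64, 128]:
            for bits in [4, 8]:
                print(f"quant_sdpa_vector_2pass_1_{t}_{head_dim}_{group_size}_{bits}")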
216 changes: 216 additions & 0 deletions mlx/backend/metal/kernels/sdpa_vector.h
// ... preceding sdpa_vector kernel code (context, unchanged) ...

template <typename T, typename U, int elem_per_thread, int bits>
METAL_FUNC U load_queries(const device T* queries, thread U* q, U scale) {
U query_sum = 0;
if (bits == 4) {
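// Pre-scale the query to match load_keys, which masks each 4-bit nibble
// but does not shift it down: element 4*i + j of the key keeps an extra
// factor of 16^j, cancelled here by dividing q[4*i + j] by 16^j.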
for (int i = 0; i < elem_per_thread; i += 4) {
q[i] = scale * queries[i];
q[i + 1] = scale * queries[i + 1];
q[i + 2] = scale * queries[i + 2];
q[i + 3] = scale * queries[i + 3];
query_sum += q[i] + q[i + 1] + q[i + 2] + q[i + 3];
q[i + 1] /= 16.0f;
q[i + 2] /= 256.0f;
q[i + 3] /= 4096.0f;
}
} else if (bits == 8) {
for (int i = 0; i < elem_per_thread; i++) {
q[i] = scale * queries[i];
query_sum += q[i];
}
}
return query_sum;
}

template <typename U, int elem_per_thread, int bits>
METAL_FUNC void load_keys(const device uint32_t* keys, thread U* k) {
if (bits == 4) {
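// Keep each nibble in place (mask only, no shift); the matching query
// element was pre-divided by 16^j in load_queries to compensate.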
auto ks = (const device uint16_t*)keys;
for (int i = 0; i < elem_per_thread / 4; i++) {
k[4 * i] = ks[i] & 0x000f;
k[4 * i + 1] = ks[i] & 0x00f0;
k[4 * i + 2] = ks[i] & 0x0f00;
k[4 * i + 3] = ks[i] & 0xf000;
}
} else if (bits == 8) {
auto ks = (const device uint8_t*)keys;
for (int i = 0; i < elem_per_thread; i++) {
k[i] = ks[i];
}
}
}

template <typename U, int elem_per_thread, int bits>
METAL_FUNC void load_values(
const device uint32_t* values,
thread U* v,
U value_scale,
U value_bias) {
auto vs = (const device uint8_t*)values;
if (bits == 4) {
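// The high nibble is masked but not shifted, so its scale is
// pre-divided by 16 (s[1]) to recover value_scale * nibble.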
U s[2] = {value_scale, value_scale / 16.0f};
for (int i = 0; i < elem_per_thread / 2; i++) {
v[2 * i] = s[0] * (vs[i] & 0x0f) + value_bias;
v[2 * i + 1] = s[1] * (vs[i] & 0xf0) + value_bias;
}
} else if (bits == 8) {
for (int i = 0; i < elem_per_thread; i++) {
v[i] = value_scale * vs[i] + value_bias;
}
}
}

// ... existing sdpa_vector_2pass_1 kernel and surrounding code (context, unchanged) ...

template <typename T, int D, int group_size, int bits>
[[kernel]] void quant_sdpa_vector_2pass_1(
const device T* queries [[buffer(0)]],
const device uint32_t* keys [[buffer(1)]],
const device T* key_scales [[buffer(2)]],
const device T* key_biases [[buffer(3)]],
const device uint32_t* values [[buffer(4)]],
const device T* value_scales [[buffer(5)]],
const device T* value_biases [[buffer(6)]],
device float* out [[buffer(7)]],
device float* sums [[buffer(8)]],
device float* maxs [[buffer(9)]],
const constant int& gqa_factor,
const constant int& N,
const constant size_t& k_stride,
const constant size_t& v_stride,
const constant size_t& k_group_stride,
const constant size_t& v_group_stride,
const constant float& scale,
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]],
uint quad_gid [[quadgroup_index_in_threadgroup]],
uint quad_lid [[thread_index_in_quadgroup]]) {
constexpr int BN = 8;
constexpr int BD = 4;
constexpr int elem_per_thread = D / BD;
const int stride = BN * D;
constexpr int blocks = 32;
constexpr int pack_factor = 32 / bits;

typedef float U;

thread U q[elem_per_thread];
thread U k[elem_per_thread];
thread U v[elem_per_thread];
thread U o[elem_per_thread];

threadgroup U outputs[BN * BD];
threadgroup U max_scores[BN];
threadgroup U sum_exp_scores[BN];

// Adjust positions
const int block_idx = tid.z;
const int head_idx = tid.y;
const int kv_head_idx = head_idx / gqa_factor;
queries += head_idx * D + quad_lid * elem_per_thread;

const int kv_idx =
(block_idx * BN + quad_gid) * D + quad_lid * elem_per_thread;
const int packed_idx = kv_idx / pack_factor;
const int k_group_idx = kv_head_idx * k_group_stride + kv_idx / group_size;
const int v_group_idx = kv_head_idx * v_group_stride + kv_idx / group_size;

keys += kv_head_idx * k_stride + packed_idx;
key_scales += k_group_idx;
key_biases += k_group_idx;
values += kv_head_idx * v_stride + packed_idx;
value_scales += v_group_idx;
value_biases += v_group_idx;

out += head_idx * blocks * D + block_idx * D + quad_lid * elem_per_thread;
sums += head_idx * blocks + block_idx;
maxs += head_idx * blocks + block_idx;

// Read the query and zero the output accumulator
U query_sum = load_queries<T, U, elem_per_thread, bits>(
queries, q, static_cast<U>(scale));
for (int i = 0; i < elem_per_thread; i++) {
o[i] = 0;
}

U max_score = -1e9;
U sum_exp_score = 0;

// For each key
for (int i = block_idx * BN + quad_gid; i < N; i += blocks * BN) {
// Read the key
load_keys<U, elem_per_thread, bits>(keys, k);

// Assume D % group_size == 0 so all the keys are in the same group
U key_scale = key_scales[0];
U key_bias = key_biases[0];

// Compute the score for this key
U score = 0;
for (int i = 0; i < elem_per_thread; i++) {
score += q[i] * k[i];
}
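// Fold dequantization into the score:
//   sum_i q_i * (key_scale * w_i + key_bias)
//     = key_scale * sum_i q_i * w_i + key_bias * sum_i q_i,
// with query_sum = sum_i q_i precomputed in load_queries.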
score = score * key_scale + query_sum * key_bias;
score = quad_sum(score);

// Update the accumulators
U new_max = max(max_score, score);
U factor = fast::exp(max_score - new_max);
U exp_score = fast::exp(score - new_max);

max_score = new_max;
sum_exp_score = sum_exp_score * factor + exp_score;

U value_scale = value_scales[0];
U value_bias = value_biases[0];
load_values<U, elem_per_thread, bits>(values, v, value_scale, value_bias);

// Update the output accumulator
for (int i = 0; i < elem_per_thread; i++) {
o[i] = o[i] * factor + exp_score * v[i];
}

// Move the pointers to the next kv
keys += blocks * stride / pack_factor;
key_scales += blocks * stride / group_size;
key_biases += blocks * stride / group_size;
values += blocks * stride / pack_factor;
value_scales += blocks * stride / group_size;
value_biases += blocks * stride / group_size;
}

// Each thread holds a partial sum of the output, so combine them.

// First let's communicate the max and sum_exp
if (quad_lid == 0) {
max_scores[quad_gid] = max_score;
sum_exp_scores[quad_gid] = sum_exp_score;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
max_score = (simd_lid < BN) ? max_scores[simd_lid] : -1e9;
U new_max = simd_max(max_score);
U factor = fast::exp(max_score - new_max);
sum_exp_score = (simd_lid < BN) ? sum_exp_scores[simd_lid] : 0;
sum_exp_score = simd_sum(sum_exp_score * factor);

// Write the sum and new max
if (simd_gid == 0) {
sums[0] = sum_exp_score;
maxs[0] = new_max;
}

// Now we need to aggregate all the outputs
for (int i = 0; i < elem_per_thread; i++) {
outputs[quad_lid * BN + quad_gid] =
o[i] * fast::exp(max_scores[quad_gid] - new_max);
threadgroup_barrier(mem_flags::mem_threadgroup);

if (quad_gid == 0) {
U output = outputs[quad_lid * BN];
for (int j = 1; j < BN; j++) {
output += outputs[quad_lid * BN + j];
}
out[i] = static_cast<T>(output);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
}
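
A quick numpy check of the mask-without-shift trick used by load_queries/load_keys above (illustrative only; none of these names come from the PR):

import numpy as np

rng = np.random.default_rng(0)
packed = rng.integers(0, 2**16, size=8, dtype=np.uint16)  # 32 packed 4-bit keys
q = rng.standard_normal(32).astype(np.float32)

# Reference: unpack each nibble properly, then take the dot product.
nibbles = np.stack([(packed >> (4 * j)) & 0xF for j in range(4)], axis=1).reshape(-1)
ref = np.dot(q, nibbles.astype(np.float32))

# Kernel-style: mask without shifting, pre-divide q[4*i + j] by 16**j instead.
masked = np.stack([packed & (0xF << (4 * j)) for j in range(4)], axis=1).reshape(-1)
q_pre = q / np.tile(np.array([1.0, 16.0, 256.0, 4096.0], dtype=np.float32), 8)
out = np.dot(q_pre, masked.astype(np.float32))

assert np.allclose(ref, out)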