diff --git a/benchmarks/python/sdpa_vector_bench.py b/benchmarks/python/sdpa_vector_bench.py index 2dc2da7f8..9fb6f36d2 100644 --- a/benchmarks/python/sdpa_vector_bench.py +++ b/benchmarks/python/sdpa_vector_bench.py @@ -1,18 +1,15 @@ import mlx.core as mx import numpy as np +from mlx.utils import tree_map from time_utils import time_fn -L = 16 +L = 65536 H = 32 H_k = 32 // 4 D = 128 def attention(q, k, v): - k = mx.quantize(k) - v = mx.quantize(v) - k = mx.dequantize(*k) - v = mx.dequantize(*v) B, Hq, L, D = q.shape _, Hk, S, _ = k.shape q = q.reshape(B, Hk, Hq // Hk, L, D) @@ -25,21 +22,31 @@ def attention(q, k, v): def sdpa(q, k, v): - k = mx.quantize(k, bits=8) - v = mx.quantize(v, bits=8) - k = mx.dequantize(*k, bits=8) - v = mx.dequantize(*v, bits=8) return mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=None) def quant_sdpa(q, k, v): - k = mx.quantize(k, bits=8) - v = mx.quantize(v, bits=8) return mx.fast.quantized_scaled_dot_product_attention( q, *k, *v, scale=1.0, mask=None, bits=8 ) +def quant_attention(q, k, v): + B, Hq, L, D = q.shape + Hk = k[0].shape[1] + + q = q.reshape((B, Hk, Hq // Hk, L, D)) + k = tree_map(lambda x: mx.expand_dims(x, axis=2), k) + v = tree_map(lambda x: mx.expand_dims(x, axis=2), v) + + scores = mx.quantized_matmul(q, *k, transpose=True) + scores = mx.softmax(scores, axis=-1) + + out = mx.quantized_matmul(scores, *v, transpose=False) + out = out.reshape((B, Hq, L, D)) + return out + + def time_self_attention_primitives(q, k, v): time_fn(attention, q, k, v) @@ -52,34 +59,22 @@ def time_self_attention_quant_sdpa(q, k, v): time_fn(quant_sdpa, q, k, v) +def time_self_attention_quant_primitives(q, k, v): + time_fn(quant_attention, q, k, v) + + if __name__ == "__main__": mx.random.seed(3) - # q = mx.random.uniform(shape=(1, H, 1, D)) - # k = mx.random.uniform(shape=(1, H_k, L, D)) - # v = mx.random.uniform(shape=(1, H_k, L, D)) - q = mx.array(np.load("/Users/alexbarron/mlx-examples/llms/queries.npy")) - k = mx.array(np.load("/Users/alexbarron/mlx-examples/llms/keys.npy")) - v = mx.array(np.load("/Users/alexbarron/mlx-examples/llms/values.npy")) - print(q.dtype) - print(q.shape, k.shape, v.shape) + q = mx.random.uniform(shape=(1, H, 1, D)) + k = mx.random.uniform(shape=(1, H_k, L, D)) + v = mx.random.uniform(shape=(1, H_k, L, D)) mx.eval(q, k, v) k_quant = mx.quantize(k) v_quant = mx.quantize(v) mx.eval(k_quant, v_quant) - # time_self_attention_sdpa(q, k, v) - # time_self_attention_quant_sdpa(q, k_quant, v_quant) - # time_self_attention_primitives(q, k, v) - q_sdpa = quant_sdpa(q, k, v) - print(q_sdpa) - # o_attention = attention(q, k, v) - # print(o_attention) - # np.testing.assert_allclose(q_sdpa, o_attention, atol=1e-5) - o_sdpa = sdpa(q, k, v) - print(o_sdpa) - np.testing.assert_allclose(q_sdpa, o_sdpa, atol=1e-5) - # print(o_sdpa[..., :64]) - # print() - # print(o_attention[..., :64]) - # np.testing.assert_allclose(o_sdpa, o_attention) + time_self_attention_sdpa(q, k, v) + time_self_attention_quant_sdpa(q, k_quant, v_quant) + time_self_attention_primitives(q, k, v) + time_self_attention_quant_primitives(q, k_quant, v_quant) diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp index 6643c380d..63e6d7d14 100644 --- a/mlx/backend/metal/scaled_dot_product_attention.cpp +++ b/mlx/backend/metal/scaled_dot_product_attention.cpp @@ -9,8 +9,6 @@ #include "mlx/backend/metal/kernels/scaled_dot_product_attention_params.h" #include "mlx/fast_primitives.h" -#include - namespace mlx::core::fast { namespace { diff --git a/mlx/fast.cpp b/mlx/fast.cpp index f01e4fe2c..cdc594bea 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -10,8 +10,6 @@ #include "mlx/ops.h" #include "mlx/transforms.h" -#include - namespace mlx::core::fast { std::vector Custom::vjp( diff --git a/python/src/fast.cpp b/python/src/fast.cpp index 83981481a..0b3947567 100644 --- a/python/src/fast.cpp +++ b/python/src/fast.cpp @@ -169,26 +169,22 @@ void init_fast(nb::module_& parent_module) { nb::sig( "def quantized_scaled_dot_product_attention(q: array, k: array, k_scales: array, k_biases: array, v: array, v_scales: array, v_biases: array, *, scale: float, mask: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - A fast implementation of multi-head attention: ``O = softmax(Q @ K.T, dim=-1) @ V``. - - Supports: - - * `Multi-Head Attention `_ - * `Grouped Query Attention `_ - * `Multi-Query Attention `_ - - Note: The softmax operation is performed in ``float32`` regardless of - the input precision. + A fast implementation of multi-head attention where the keys and values are quantized. - Note: For Grouped Query Attention and Multi-Query Attention, the ``k`` - and ``v`` inputs should not be pre-tiled to match ``q``. + see :func:`scaled_dot_product_attention` for more details. Args: q (array): Input query array. k (array): Input keys array. + k_scales (array): Scales for the quantized keys array. + k_biases (array): Biases for the quantized keys array. v (array): Input values array. + v_scales (array): Scales for the quantized values array. + v_biases (array): Biases for the quantized values array. scale (float): Scale for queries (typically ``1.0 / sqrt(q.shape(-1)``) mask (array, optional): An additive mask to apply to the query-key scores. + group_size (int): The group size used in the KV quantization. + bits (int): The bits used in the KV quantization. Returns: array: The output array. )pbdoc");