Skip to content

Commit

Permalink
add test
Browse files (browse the repository at this point in the history)
  • Loading branch information
Alex Barron committed Dec 6, 2024
1 parent 12a4d89 commit 3507c10
Showing 1 changed file with 51 additions and 48 deletions.
99 changes: 51 additions & 48 deletions python/tests/test_fast_sdpa.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import math
import unittest
from itertools import product

import mlx.core as mx
import mlx_tests
Expand Down Expand Up @@ -113,61 +114,63 @@ def test_fast_sdpa(self):
R = 1
Dk = 128
scale = float(1.0 / np.sqrt(128.0))
q_npy = np.random.normal(0.0, 1.0, (1, 32, R, Dk)).astype(np.float32)
k_npy = np.random.normal(0.0, 1.0, (1, 32, L, Dk)).astype(np.float32)
v_npy = np.random.normal(0.0, 1.0, (1, 32, L, Dk)).astype(np.float32)
q = mx.random.normal(shape=(1, 32, R, Dk))
k = mx.random.normal(shape=(1, 32, L, Dk))
v = mx.random.normal(shape=(1, 32, L, Dk))

q_mlx = mx.array(q_npy)
k_mlx = mx.array(k_npy)
v_mlx = mx.array(v_npy)
reference = mlx_primitives_sdpa(q, k, v, scale)

reference = mlx_primitives_sdpa(q_mlx, k_mlx, v_mlx, scale)
o = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=None)

o_mlx = mx.fast.scaled_dot_product_attention(
q_mlx, k_mlx, v_mlx, scale=scale, mask=None
)

self.assertListEqual(list(reference.shape), list(o_mlx.shape))
self.assertTrue(mx.allclose(o_mlx, reference, atol=1e-4))
self.assertListEqual(list(reference.shape), list(o.shape))
self.assertTrue(mx.allclose(o, reference, atol=1e-4))

B = 1
H = 32
dtypes = [np.float32]
if self.is_apple_silicon:
dtypes.append(np.half)

for SEQUENCE_LENGTH in [1, 7, 9, 32, 63, 67, 129, 400, 2000]:
for DO_GQA in [0, 1]:
for DTYPE in dtypes:
n_kv_heads = 8 if DO_GQA else 32
q_npy = np.random.normal(0.0, 1.0, (B, H, R, Dk)).astype(DTYPE)
k_npy = np.random.normal(
0.0, 1.0, (B, n_kv_heads, SEQUENCE_LENGTH, Dk)
).astype(DTYPE)
v_npy = np.random.normal(
0.0, 1.0, (B, n_kv_heads, SEQUENCE_LENGTH, Dk)
).astype(DTYPE)

q_mlx = mx.array(q_npy)
k_mlx = mx.array(k_npy)
v_mlx = mx.array(v_npy)

reference = mlx_primitives_sdpa_with_gqa(q_mlx, k_mlx, v_mlx, scale)
o_mlx = mx.fast.scaled_dot_product_attention(
q_mlx, k_mlx, v_mlx, scale=scale
)

self.assertListEqual(list(reference.shape), list(o_mlx.shape))
rtol = 1e-5
atol = 1e-1

if SEQUENCE_LENGTH > 500:
rtol = 1e-2

if DTYPE == np.half:
rtol = 1e-2

self.assertTrue(mx.allclose(o_mlx, reference, rtol=rtol, atol=atol))
tests = product(
[1, 7, 9, 32, 63, 67, 129, 2000], # sequence length
[False, True], # gqa
[mx.float32, mx.float16],
[4, 8], # bits
)
for sequence_length, do_gqa, dtype, bits in tests:
with self.subTest(
sequence_length=sequence_length, gqa=do_gqa, dtype=dtype, bits=bits
):
n_kv_heads = 8 if do_gqa else 32
q = mx.random.normal(shape=(B, H, R, Dk), dtype=dtype)
k = mx.random.normal(
shape=(B, n_kv_heads, sequence_length, Dk), dtype=dtype
)
v = mx.random.normal(
shape=(B, n_kv_heads, sequence_length, Dk), dtype=dtype
)

k_q = mx.quantize(k, bits=bits)
v_q = mx.quantize(v, bits=bits)
k_d = mx.dequantize(*k_q, bits=bits)
v_d = mx.dequantize(*v_q, bits=bits)

reference = mlx_primitives_sdpa_with_gqa(q, k_d, v_d, scale)
o = mx.fast.scaled_dot_product_attention(q, k_d, v_d, scale=scale)
o_q = mx.fast.quantized_scaled_dot_product_attention(
q, *k_q, *v_q, scale=scale, bits=bits
)

self.assertListEqual(list(reference.shape), list(o.shape))
rtol = 1e-5
atol = 1e-1

if sequence_length > 500:
rtol = 1e-2

if dtype == mx.float16:
rtol = 1e-2

# np.testing.assert_allclose(o_q, reference, rtol=rtol, atol=atol)
self.assertTrue(mx.allclose(o_q, reference, rtol=rtol, atol=atol))
self.assertTrue(mx.allclose(o, reference, rtol=rtol, atol=atol))

q = mx.random.normal(shape=(1, 32, 1, Dk))
k = mx.random.normal(shape=(1, 32, 32, Dk))
Expand Down

0 comments on commit 3507c10

Please sign in to comment.