diff --git a/python/tests/test_fast_sdpa.py b/python/tests/test_fast_sdpa.py index 1df48bc7f..ad3262e7c 100644 --- a/python/tests/test_fast_sdpa.py +++ b/python/tests/test_fast_sdpa.py @@ -1,5 +1,6 @@ import math import unittest +from itertools import product import mlx.core as mx import mlx_tests @@ -113,61 +114,63 @@ def test_fast_sdpa(self): R = 1 Dk = 128 scale = float(1.0 / np.sqrt(128.0)) - q_npy = np.random.normal(0.0, 1.0, (1, 32, R, Dk)).astype(np.float32) - k_npy = np.random.normal(0.0, 1.0, (1, 32, L, Dk)).astype(np.float32) - v_npy = np.random.normal(0.0, 1.0, (1, 32, L, Dk)).astype(np.float32) + q = mx.random.normal(shape=(1, 32, R, Dk)) + k = mx.random.normal(shape=(1, 32, L, Dk)) + v = mx.random.normal(shape=(1, 32, L, Dk)) - q_mlx = mx.array(q_npy) - k_mlx = mx.array(k_npy) - v_mlx = mx.array(v_npy) + reference = mlx_primitives_sdpa(q, k, v, scale) - reference = mlx_primitives_sdpa(q_mlx, k_mlx, v_mlx, scale) + o = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=None) - o_mlx = mx.fast.scaled_dot_product_attention( - q_mlx, k_mlx, v_mlx, scale=scale, mask=None - ) - - self.assertListEqual(list(reference.shape), list(o_mlx.shape)) - self.assertTrue(mx.allclose(o_mlx, reference, atol=1e-4)) + self.assertListEqual(list(reference.shape), list(o.shape)) + self.assertTrue(mx.allclose(o, reference, atol=1e-4)) B = 1 H = 32 - dtypes = [np.float32] - if self.is_apple_silicon: - dtypes.append(np.half) - for SEQUENCE_LENGTH in [1, 7, 9, 32, 63, 67, 129, 400, 2000]: - for DO_GQA in [0, 1]: - for DTYPE in dtypes: - n_kv_heads = 8 if DO_GQA else 32 - q_npy = np.random.normal(0.0, 1.0, (B, H, R, Dk)).astype(DTYPE) - k_npy = np.random.normal( - 0.0, 1.0, (B, n_kv_heads, SEQUENCE_LENGTH, Dk) - ).astype(DTYPE) - v_npy = np.random.normal( - 0.0, 1.0, (B, n_kv_heads, SEQUENCE_LENGTH, Dk) - ).astype(DTYPE) - - q_mlx = mx.array(q_npy) - k_mlx = mx.array(k_npy) - v_mlx = mx.array(v_npy) - - reference = mlx_primitives_sdpa_with_gqa(q_mlx, k_mlx, v_mlx, scale) - o_mlx = mx.fast.scaled_dot_product_attention( - q_mlx, k_mlx, v_mlx, scale=scale - ) - - self.assertListEqual(list(reference.shape), list(o_mlx.shape)) - rtol = 1e-5 - atol = 1e-1 - - if SEQUENCE_LENGTH > 500: - rtol = 1e-2 - - if DTYPE == np.half: - rtol = 1e-2 - - self.assertTrue(mx.allclose(o_mlx, reference, rtol=rtol, atol=atol)) + tests = product( + [1, 7, 9, 32, 63, 67, 129, 2000], # sequence length + [False, True], # gqa + [mx.float32, mx.float16], + [4, 8], # bits + ) + for sequence_length, do_gqa, dtype, bits in tests: + with self.subTest( + sequence_length=sequence_length, gqa=do_gqa, dtype=dtype, bits=bits + ): + n_kv_heads = 8 if do_gqa else 32 + q = mx.random.normal(shape=(B, H, R, Dk), dtype=dtype) + k = mx.random.normal( + shape=(B, n_kv_heads, sequence_length, Dk), dtype=dtype + ) + v = mx.random.normal( + shape=(B, n_kv_heads, sequence_length, Dk), dtype=dtype + ) + + k_q = mx.quantize(k, bits=bits) + v_q = mx.quantize(v, bits=bits) + k_d = mx.dequantize(*k_q, bits=bits) + v_d = mx.dequantize(*v_q, bits=bits) + + reference = mlx_primitives_sdpa_with_gqa(q, k_d, v_d, scale) + o = mx.fast.scaled_dot_product_attention(q, k_d, v_d, scale=scale) + o_q = mx.fast.quantized_scaled_dot_product_attention( + q, *k_q, *v_q, scale=scale, bits=bits + ) + + self.assertListEqual(list(reference.shape), list(o.shape)) + rtol = 1e-5 + atol = 1e-1 + + if sequence_length > 500: + rtol = 1e-2 + + if dtype == mx.float16: + rtol = 1e-2 + + # np.testing.assert_allclose(o_q, reference, rtol=rtol, atol=atol) + self.assertTrue(mx.allclose(o_q, reference, rtol=rtol, atol=atol)) + self.assertTrue(mx.allclose(o, reference, rtol=rtol, atol=atol)) q = mx.random.normal(shape=(1, 32, 1, Dk)) k = mx.random.normal(shape=(1, 32, 32, Dk))