diff --git a/benchmarks/python/sdpa_vector_bench.py b/benchmarks/python/sdpa_vector_bench.py
index 291f87203..6e56fa1ef 100644
--- a/benchmarks/python/sdpa_vector_bench.py
+++ b/benchmarks/python/sdpa_vector_bench.py
@@ -1,58 +1,94 @@
-import argparse
-import math
-
 import mlx.core as mx
+import numpy as np
+from mlx.utils import tree_map
 from time_utils import time_fn
 
-L = 16384
+L = 32768
 H = 32
 H_k = H // 4
 D = 128
 dtype = mx.float16
-loops = 10
+bits = 8
+
+loops = 20
 
 
 def attention(q, k, v):
-    def _sdpa(q, k, v):
+    for _ in range(loops):
         B, Hq, L, D = q.shape
         _, Hk, S, _ = k.shape
         q = q.reshape(B, Hk, Hq // Hk, L, D)
-        k = k[:, :, None, :, :]
-        v = v[:, :, None, :, :]
-        s = q @ k.transpose(0, 1, 2, 4, 3)
+        ke = k[:, :, None, :, :]
+        ve = v[:, :, None, :, :]
+        s = q @ ke.transpose(0, 1, 2, 4, 3)
         p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
-        o = p @ v
-        return o.reshape(B, Hq, L, D)
-
-    for i in range(loops):
-        q = _sdpa(q, k, v)
+        q = p @ ve
+        q = q.reshape(B, Hq, L, D)
     return q
 
 
 def sdpa(q, k, v):
-    for i in range(loops):
-        q = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0)
+    for _ in range(loops):
+        q = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=None)
     return q
 
 
-def time_self_attention_primitives():
-    mx.random.seed(3)
-    q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype)
-    k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
-    v = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
-    mx.eval(q, k, v)
+def quant_sdpa(q, k, v, bits=4):
+    for _ in range(loops):
+        q = mx.fast.quantized_scaled_dot_product_attention(
+            q, *k, *v, scale=1.0, mask=None, bits=bits
+        )
+    return q
+
+
+def quant_attention(q, k, v, bits=4):
+    for _ in range(loops):
+        B, Hq, L, D = q.shape
+        Hk = k[0].shape[1]
+
+        q = q.reshape((B, Hk, Hq // Hk, L, D))
+        ke = tree_map(lambda x: mx.expand_dims(x, axis=2), k)
+        ve = tree_map(lambda x: mx.expand_dims(x, axis=2), v)
+
+        scores = mx.quantized_matmul(q, *ke, transpose=True, bits=bits)
+        scores = mx.softmax(scores, axis=-1)
+
+        q = mx.quantized_matmul(scores, *ve, transpose=False, bits=bits)
+        q = q.reshape((B, Hq, L, D))
+    return q
+
+
+def time_self_attention_primitives(q, k, v):
     time_fn(attention, q, k, v)
 
 
-def time_self_attention_sdpa():
-    mx.random.seed(3)
-    q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype)
-    k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
-    v = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
-    mx.eval(q, k, v)
+def time_self_attention_sdpa(q, k, v):
     time_fn(sdpa, q, k, v)
 
 
+def time_self_attention_quant_sdpa(q, k, v, bits=4):
+    time_fn(quant_sdpa, q, k, v, bits)
+
+
+def time_self_attention_quant_primitives(q, k, v, bits=4):
+    time_fn(quant_attention, q, k, v, bits)
+
+
 if __name__ == "__main__":
-    time_self_attention_sdpa()
-    time_self_attention_primitives()
+    mx.random.seed(3)
+    q = mx.random.uniform(shape=(1, H, 1, D), dtype=dtype)
+    k = mx.random.uniform(shape=(1, H_k, L, D), dtype=dtype)
+    v = mx.random.uniform(shape=(1, H_k, L, D), dtype=dtype)
+    mx.eval(q, k, v)
+
+    k_quant = mx.quantize(k, bits=bits)
+    v_quant = mx.quantize(v, bits=bits)
+    mx.eval(k_quant, v_quant)
+
+    k = mx.dequantize(*k_quant, bits=bits)
+    v = mx.dequantize(*v_quant, bits=bits)
+
+    time_self_attention_sdpa(q, k, v)
+    time_self_attention_quant_sdpa(q, k_quant, v_quant, bits)
+    time_self_attention_primitives(q, k, v)
+    time_self_attention_quant_primitives(q, k_quant, v_quant, bits)
diff --git a/mlx/backend/metal/kernels/scaled_dot_product_attention.metal b/mlx/backend/metal/kernels/scaled_dot_product_attention.metal
index b5bc9607e..9beec77b1 100644
--- a/mlx/backend/metal/kernels/scaled_dot_product_attention.metal
+++ b/mlx/backend/metal/kernels/scaled_dot_product_attention.metal
@@ -20,4 +20,33 @@ using namespace metal;
 instantiate_sdpa_vector_heads(float)
 instantiate_sdpa_vector_heads(bfloat16_t)
 instantiate_sdpa_vector_heads(float16_t)
+
+// Quantized SDPA vector instantiations
+#define instantiate_quant_sdpa_vector(name, type, head_dim, group_size, bits) \
+  instantiate_kernel(                                                         \
+      #name "_" #type "_" #head_dim "_" #group_size "_" #bits,                \
+      name, type, head_dim, group_size, bits)
+
+#define instantiate_quant_sdpa_vector_passes(type, heads, group_size, bits)         \
+  instantiate_quant_sdpa_vector(quant_sdpa_vector, type, heads, group_size, bits)   \
+  instantiate_quant_sdpa_vector(quant_sdpa_vector_2pass_1, type, heads, group_size, bits)
+
+#define instantiate_quant_sdpa_vector_bits(type, heads, group_size)  \
+  instantiate_quant_sdpa_vector_passes(type, heads, group_size, 4)   \
+  instantiate_quant_sdpa_vector_passes(type, heads, group_size, 8)
+
+#define instantiate_quant_sdpa_vector_group_size(type, heads) \
+  instantiate_quant_sdpa_vector_bits(type, heads, 32)         \
+  instantiate_quant_sdpa_vector_bits(type, heads, 64)         \
+  instantiate_quant_sdpa_vector_bits(type, heads, 128)
+
+#define instantiate_quant_sdpa_vector_heads(type)     \
+  instantiate_quant_sdpa_vector_group_size(type, 64)  \
+  instantiate_quant_sdpa_vector_group_size(type, 96)  \
+  instantiate_quant_sdpa_vector_group_size(type, 128)
+
+instantiate_quant_sdpa_vector_heads(float)
+instantiate_quant_sdpa_vector_heads(bfloat16_t)
+instantiate_quant_sdpa_vector_heads(float16_t)
+
 // clang-format on
diff --git a/mlx/backend/metal/kernels/sdpa_vector.h b/mlx/backend/metal/kernels/sdpa_vector.h
index 8b6af638e..49eb35b1f 100644
--- a/mlx/backend/metal/kernels/sdpa_vector.h
+++ b/mlx/backend/metal/kernels/sdpa_vector.h
@@ -113,6 +113,208 @@ template <typename T, int D>
   }
 }
 
+template <typename T, typename U, int elem_per_thread, int bits>
+METAL_FUNC U load_queries(const device T* queries, thread U* q, U scale) {
+  U query_sum = 0;
+  if (bits == 4) {
+    for (int i = 0; i < elem_per_thread; i += 4) {
+      q[i] = scale * queries[i];
+      q[i + 1] = scale * queries[i + 1];
+      q[i + 2] = scale * queries[i + 2];
+      q[i + 3] = scale * queries[i + 3];
+      query_sum += q[i] + q[i + 1] + q[i + 2] + q[i + 3];
+      // Pre-scale the upper elements so they pair up with the unshifted
+      // 4-bit nibbles produced by load_keys.
+      q[i + 1] /= 16.0f;
+      q[i + 2] /= 256.0f;
+      q[i + 3] /= 4096.0f;
+    }
+  } else if (bits == 8) {
+    for (int i = 0; i < elem_per_thread; i++) {
+      q[i] = scale * queries[i];
+      query_sum += q[i];
+    }
+  }
+  return query_sum;
+}
+
+template <typename U, int elem_per_thread, int bits>
+METAL_FUNC void load_keys(const device uint32_t* keys, thread U* k) {
+  if (bits == 4) {
+    auto ks = (const device uint16_t*)keys;
+    for (int i = 0; i < elem_per_thread / 4; i++) {
+      k[4 * i] = ks[i] & 0x000f;
+      k[4 * i + 1] = ks[i] & 0x00f0;
+      k[4 * i + 2] = ks[i] & 0x0f00;
+      k[4 * i + 3] = ks[i] & 0xf000;
+    }
+  } else if (bits == 8) {
+    auto ks = (const device uint8_t*)keys;
+    for (int i = 0; i < elem_per_thread; i++) {
+      k[i] = ks[i];
+    }
+  }
+}
+
+template <typename U, int elem_per_thread, int bits>
+METAL_FUNC void load_values(
+    const device uint32_t* values,
+    thread U* v,
+    U value_scale,
+    U value_bias) {
+  auto vs = (const device uint8_t*)values;
+  if (bits == 4) {
+    U s[2] = {value_scale, value_scale / 16.0f};
+    for (int i = 0; i < elem_per_thread / 2; i++) {
+      v[2 * i] = s[0] * (vs[i] & 0x0f) + value_bias;
+      v[2 * i + 1] = s[1] * (vs[i] & 0xf0) + value_bias;
+    }
+  } else if (bits == 8) {
+    for (int i = 0; i < elem_per_thread; i++) {
+      v[i] = value_scale * vs[i] + value_bias;
+    }
+  }
+}
+
+template <typename T, int D, int group_size, int bits>
+[[kernel]] void quant_sdpa_vector(
+    const device T* queries [[buffer(0)]],
+    const device uint32_t* keys [[buffer(1)]],
+    const device T* key_scales [[buffer(2)]],
+    const device T* key_biases [[buffer(3)]],
+    const device uint32_t* values [[buffer(4)]],
+    const device T* value_scales [[buffer(5)]],
+    const device T* value_biases [[buffer(6)]],
+    device T* out [[buffer(7)]],
+    const constant int& gqa_factor,
+    const constant int& N,
+    const constant size_t& k_stride,
+    const constant size_t& group_stride,
+    const constant float& scale,
+    uint3 tid [[threadgroup_position_in_grid]],
+    uint simd_gid [[simdgroup_index_in_threadgroup]],
+    uint simd_lid [[thread_index_in_simdgroup]],
+    uint quad_gid [[quadgroup_index_in_threadgroup]],
+    uint quad_lid [[thread_index_in_quadgroup]]) {
+  constexpr int BN = 32;
+  constexpr int BD = 4;
+  constexpr int elem_per_thread = D / BD;
+  constexpr int pack_factor = 32 / bits;
+
+  const int stride = BN * D;
+
+  typedef float U;
+
+  thread U q[elem_per_thread];
+  thread U k[elem_per_thread];
+  thread U v[elem_per_thread];
+  thread U o[elem_per_thread];
+
+  threadgroup U outputs[BN * BD];
+  threadgroup U max_scores[BN];
+  threadgroup U sum_exp_scores[BN];
+
+  // Adjust positions
+  const int head_idx = tid.y;
+  const int kv_head_idx = head_idx / gqa_factor;
+  queries += head_idx * D + quad_lid * elem_per_thread;
+
+  const int kv_idx = quad_gid * D + quad_lid * elem_per_thread;
+  const int packed_idx = kv_head_idx * k_stride + kv_idx / pack_factor;
+  const int group_idx = kv_head_idx * group_stride + kv_idx / group_size;
+  keys += packed_idx;
+  key_scales += group_idx;
+  key_biases += group_idx;
+  values += packed_idx;
+  value_scales += group_idx;
+  value_biases += group_idx;
+
+  out += head_idx * D + simd_gid * elem_per_thread;
+
+  // Read the query and 0 the output accumulator
+  U query_sum = load_queries<T, U, elem_per_thread, bits>(
+      queries, q, static_cast<U>(scale));
+  for (int i = 0; i < elem_per_thread; i++) {
+    o[i] = 0;
+  }
+
+  U max_score = -INFINITY;
+  U sum_exp_score = 0;
+
+  // For each key
+  for (int i = quad_gid; i < N; i += BN) {
+    load_keys<U, elem_per_thread, bits>(keys, k);
+
+    // Assume D % group_size == 0 so all the keys are in the same group
+    U key_scale = key_scales[0];
+    U key_bias = key_biases[0];
+
+    // Compute the i-th score
+    U score = 0;
+    for (int i = 0; i < elem_per_thread; i++) {
+      score += q[i] * k[i];
+    }
+    // q . dequant(k) = key_scale * (q . k_packed) + key_bias * sum(q)
+    score = score * key_scale + query_sum * key_bias;
+    score = quad_sum(score);
+
+    // Update the accumulators
+    U new_max = max(max_score, score);
+    U factor = fast::exp(max_score - new_max);
+    U exp_score = fast::exp(score - new_max);
+
+    max_score = new_max;
+    sum_exp_score = sum_exp_score * factor + exp_score;
+
+    U value_scale = value_scales[0];
+    U value_bias = value_biases[0];
+
+    // Load the values
+    load_values<U, elem_per_thread, bits>(values, v, value_scale, value_bias);
+
+    // Update the output accumulator
+    for (int i = 0; i < elem_per_thread; i++) {
+      o[i] = o[i] * factor + exp_score * v[i];
+    }
+
+    // Move the pointers to the next kv
+    keys += stride / pack_factor;
+    key_scales += stride / group_size;
+    key_biases += stride / group_size;
+    values += stride / pack_factor;
+    value_scales += stride / group_size;
+    value_biases += stride / group_size;
+  }
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+
+  // Each thread has a partial part of the output so we need to combine them.
+
+  // First let's communicate the max and sum_exp
+  // Each quadgroup communicates its max score
+  if (quad_lid == 0) {
+    max_scores[quad_gid] = max_score;
+    sum_exp_scores[quad_gid] = sum_exp_score;
+  }
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+  max_score = max_scores[simd_lid];
+  U new_max = simd_max(max_score);
+  U factor = fast::exp(max_score - new_max);
+  sum_exp_score = simd_sum(sum_exp_scores[simd_lid] * factor);
+
+  // Now we need to aggregate all the outputs
+  for (int i = 0; i < elem_per_thread; i++) {
+    // 128 threads with 32 values per thread
+    outputs[simd_gid * BN + simd_lid] = o[i];
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    o[i] = simd_sum(outputs[simd_lid * BD + simd_gid] * factor) / sum_exp_score;
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+  }
+
+  // And write the output
+  if (simd_lid == 0) {
+    for (int i = 0; i < elem_per_thread; i++) {
+      out[i] = static_cast<T>(o[i]);
+    }
+  }
+}
+
 template <typename T, int D>
 [[kernel]] void sdpa_vector_2pass_1(
     const device T* queries [[buffer(0)]],
@@ -290,3 +492,158 @@ template <typename T, int D>
     }
   }
 }
+
+template <typename T, int D, int group_size, int bits>
+[[kernel]] void quant_sdpa_vector_2pass_1(
+    const device T* queries [[buffer(0)]],
+    const device uint32_t* keys [[buffer(1)]],
+    const device T* key_scales [[buffer(2)]],
+    const device T* key_biases [[buffer(3)]],
+    const device uint32_t* values [[buffer(4)]],
+    const device T* value_scales [[buffer(5)]],
+    const device T* value_biases [[buffer(6)]],
+    device float* out [[buffer(7)]],
+    device float* sums [[buffer(8)]],
+    device float* maxs [[buffer(9)]],
+    const constant int& gqa_factor,
+    const constant int& N,
+    const constant size_t& k_stride,
+    const constant size_t& v_stride,
+    const constant size_t& k_group_stride,
+    const constant size_t& v_group_stride,
+    const constant float& scale,
+    uint3 tid [[threadgroup_position_in_grid]],
+    uint simd_gid [[simdgroup_index_in_threadgroup]],
+    uint simd_lid [[thread_index_in_simdgroup]],
+    uint quad_gid [[quadgroup_index_in_threadgroup]],
+    uint quad_lid [[thread_index_in_quadgroup]]) {
+  constexpr int BN = 8;
+  constexpr int BD = 4;
+  constexpr int elem_per_thread = D / BD;
+  const int stride = BN * D;
+  constexpr int blocks = 32;
+  constexpr int pack_factor = 32 / bits;
+
+  typedef float U;
+
+  thread U q[elem_per_thread];
+  thread U k[elem_per_thread];
+  thread U v[elem_per_thread];
+  thread U o[elem_per_thread];
+
+  threadgroup U outputs[BN * BD];
+  threadgroup U max_scores[BN];
+  threadgroup U sum_exp_scores[BN];
+
+  // Adjust positions
+  const int block_idx = tid.z;
+  const int head_idx = tid.y;
+  const int kv_head_idx = head_idx / gqa_factor;
+  queries += head_idx * D + quad_lid * elem_per_thread;
+
+  const int kv_idx =
+      (block_idx * BN + quad_gid) * D + quad_lid * elem_per_thread;
+  const int packed_idx = kv_idx / pack_factor;
+  const int k_group_idx = kv_head_idx * k_group_stride + kv_idx / group_size;
+  const int v_group_idx = kv_head_idx * v_group_stride + kv_idx / group_size;
+
+  keys += kv_head_idx * k_stride + packed_idx;
+  key_scales += k_group_idx;
+  key_biases += k_group_idx;
+  values += kv_head_idx * v_stride + packed_idx;
+  value_scales += v_group_idx;
+  value_biases += v_group_idx;
+
+  out += head_idx * blocks * D + block_idx * D + quad_lid * elem_per_thread;
+  sums += head_idx * blocks + block_idx;
+  maxs += head_idx * blocks + block_idx;
+
+  // Read the query and 0 the output accumulator
+  U query_sum = load_queries<T, U, elem_per_thread, bits>(
+      queries, q, static_cast<U>(scale));
+  for (int i = 0; i < elem_per_thread; i++) {
+    o[i] = 0;
+  }
+
+  U max_score = -1e9;
+  U sum_exp_score = 0;
+
+  // For each key
+  for (int i = block_idx * BN + quad_gid; i < N; i += blocks * BN) {
+    // Read the key
+    load_keys<U, elem_per_thread, bits>(keys, k);
+
+    // Assume D % group_size == 0 so all the keys are in the same group
+    U key_scale = key_scales[0];
+    U key_bias = key_biases[0];
+
+    // Compute the i-th score
+    U score = 0;
+    for (int i = 0; i < elem_per_thread; i++) {
+      score += q[i] * k[i];
+    }
+    score = score * key_scale + query_sum * key_bias;
+    score = quad_sum(score);
+
+    // Update the accumulators
+    U new_max = max(max_score, score);
+    U factor = fast::exp(max_score - new_max);
+    U exp_score = fast::exp(score - new_max);
+
+    max_score = new_max;
+    sum_exp_score = sum_exp_score * factor + exp_score;
+
+    U value_scale = value_scales[0];
+    U value_bias = value_biases[0];
+    load_values<U, elem_per_thread, bits>(values, v, value_scale, value_bias);
+
+    // Update the output accumulator
+    for (int i = 0; i < elem_per_thread; i++) {
+      o[i] = o[i] * factor + exp_score * v[i];
+    }
+
+    // Move the pointers to the next kv
+    keys += blocks * stride / pack_factor;
+    key_scales += blocks * stride / group_size;
+    key_biases += blocks * stride / group_size;
+    values += blocks * stride / pack_factor;
+    value_scales += blocks * stride / group_size;
+    value_biases += blocks * stride / group_size;
+  }
+
+  // Each thread has a partial part of the output so we need to combine them.
+
+  // First let's communicate the max and sum_exp
+  if (quad_lid == 0) {
+    max_scores[quad_gid] = max_score;
+    sum_exp_scores[quad_gid] = sum_exp_score;
+  }
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+  max_score = (simd_lid < BN) ? max_scores[simd_lid] : -1e9;
+  U new_max = simd_max(max_score);
+  U factor = fast::exp(max_score - new_max);
+  sum_exp_score = (simd_lid < BN) ? sum_exp_scores[simd_lid] : 0;
+  sum_exp_score = simd_sum(sum_exp_score * factor);
+
+  // Write the sum and new max
+  if (simd_gid == 0) {
+    sums[0] = sum_exp_score;
+    maxs[0] = new_max;
+  }
+
+  // Now we need to aggregate all the outputs
+  for (int i = 0; i < elem_per_thread; i++) {
+    outputs[quad_lid * BN + quad_gid] =
+        o[i] * fast::exp(max_scores[quad_gid] - new_max);
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    if (quad_gid == 0) {
+      U output = outputs[quad_lid * BN];
+      for (int j = 1; j < BN; j++) {
+        output += outputs[quad_lid * BN + j];
+      }
+      out[i] = static_cast<float>(output);
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+  }
+}
diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp
index f600a4890..0f87e6027 100644
--- a/mlx/backend/metal/scaled_dot_product_attention.cpp
+++ b/mlx/backend/metal/scaled_dot_product_attention.cpp
@@ -242,6 +242,171 @@ void sdpa_vector_2pass(
   compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
 }
 
+void quant_sdpa_vector(
+    const Stream& s,
+    metal::Device& d,
+    const array& q,
+    const array& k,
+    const array& k_scales,
+    const array& k_biases,
+    const array& v,
+    const array& v_scales,
+    const array& v_biases,
+    array& out,
+    float scale,
+    int group_size,
+    int bits) {
+  // Set the kernel name
+  std::string kname;
+  kname.reserve(96);
+  kname += "quant_sdpa_vector_";
+  kname += get_type_string(q.dtype());
+  kname += "_";
+  kname += std::to_string(q.shape(-1));
+  kname += "_";
+  kname += std::to_string(group_size);
+  kname += "_";
+  kname += std::to_string(bits);
+
+  // Compute the necessary sizes
+  int gqa_factor = q.shape(1) / k.shape(1);
+  int N = k.shape(2);
+  int B = q.shape(0) * q.shape(1);
+  size_t stride = k.strides()[1];
+  size_t group_stride = k_scales.strides()[1];
+  MTL::Size group_dims(128, 1, 1);
+  MTL::Size grid_dims(1, B, 1);
+
+  // Get the kernel
+  auto& compute_encoder = d.get_command_encoder(s.index);
+  auto kernel = d.get_kernel(kname);
+  compute_encoder.set_compute_pipeline_state(kernel);
+
+  // Set its arguments
+  compute_encoder.set_input_array(q.data_shared_ptr() == nullptr ? out : q, 0);
+  compute_encoder.set_input_array(k, 1);
+  compute_encoder.set_input_array(k_scales, 2);
+  compute_encoder.set_input_array(k_biases, 3);
+  compute_encoder.set_input_array(v, 4);
+  compute_encoder.set_input_array(v_scales, 5);
+  compute_encoder.set_input_array(v_biases, 6);
+  compute_encoder.set_output_array(out, 7);
+  compute_encoder.set_bytes(&gqa_factor, sizeof(int), 8);
+  compute_encoder.set_bytes(&N, sizeof(int), 9);
+  compute_encoder.set_bytes(&stride, sizeof(size_t), 10);
+  compute_encoder.set_bytes(&group_stride, sizeof(size_t), 11);
+  compute_encoder.set_bytes(&scale, sizeof(float), 12);
+
+  // Launch
+  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
+}
+
+void quant_sdpa_vector_2pass(
+    const Stream& s,
+    metal::Device& d,
+    const array& q,
+    const array& k,
+    const array& k_scales,
+    const array& k_biases,
+    const array& v,
+    const array& v_scales,
+    const array& v_biases,
+    array& out,
+    float scale,
+    int group_size,
+    int bits) {
+  // Set the kernel name
+  std::string kname;
+  kname.reserve(96);
+  kname += "quant_sdpa_vector_2pass_1_";
+  kname += get_type_string(q.dtype());
+  kname += "_";
+  kname += std::to_string(q.shape(-1));
+  kname += "_";
+  kname += std::to_string(group_size);
+  kname += "_";
+  kname += std::to_string(bits);
+
+  // Compute the necessary sizes
+  int gqa_factor = q.shape(1) / k.shape(1);
+  int N = k.shape(2);
+  int blocks = 32;
+  int B = q.shape(0) * q.shape(1);
+  size_t k_stride = k.strides()[1];
+  size_t v_stride = v.strides()[1];
+  size_t k_group_stride = k_scales.strides()[1];
+  size_t v_group_stride = v_scales.strides()[1];
+  MTL::Size group_dims(8 * 4, 1, 1);
+  MTL::Size grid_dims(1, B, blocks);
+
+  // Allocate the intermediates
+  std::vector<int> intermediate_shape;
+  intermediate_shape.reserve(out.ndim() + 1);
+  intermediate_shape.insert(
+      intermediate_shape.end(), out.shape().begin(), out.shape().end() - 1);
+  intermediate_shape.push_back(blocks);
+  intermediate_shape.push_back(out.shape().back());
+  array intermediate(intermediate_shape, float32, nullptr, {});
+  intermediate_shape.pop_back();
+  array sums(intermediate_shape, float32, nullptr, {});
+  array maxs(std::move(intermediate_shape), float32, nullptr, {});
+  intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
+  sums.set_data(allocator::malloc_or_wait(sums.nbytes()));
+  maxs.set_data(allocator::malloc_or_wait(maxs.nbytes()));
+  d.add_temporary(intermediate, s.index);
+  d.add_temporary(sums, s.index);
+  d.add_temporary(maxs, s.index);
+
+  // Get the kernel
+  auto& compute_encoder = d.get_command_encoder(s.index);
+  auto kernel = d.get_kernel(kname);
+  compute_encoder.set_compute_pipeline_state(kernel);
+
+  // Set its arguments
+  compute_encoder.set_input_array(q.data_shared_ptr() == nullptr ? out : q, 0);
+  compute_encoder.set_input_array(k, 1);
+  compute_encoder.set_input_array(k_scales, 2);
+  compute_encoder.set_input_array(k_biases, 3);
+  compute_encoder.set_input_array(v, 4);
+  compute_encoder.set_input_array(v_scales, 5);
+  compute_encoder.set_input_array(v_biases, 6);
+  compute_encoder.set_output_array(intermediate, 7);
+  compute_encoder.set_output_array(sums, 8);
+  compute_encoder.set_output_array(maxs, 9);
+  compute_encoder.set_bytes(gqa_factor, 10);
+  compute_encoder.set_bytes(N, 11);
+  compute_encoder.set_bytes(k_stride, 12);
+  compute_encoder.set_bytes(v_stride, 13);
+  compute_encoder.set_bytes(k_group_stride, 14);
+  compute_encoder.set_bytes(v_group_stride, 15);
+  compute_encoder.set_bytes(scale, 16);
+
+  // Launch
+  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
+
+  // Final pass
+  kname.clear();
+  kname += "sdpa_vector_2pass_2_";
+  kname += get_type_string(q.dtype());
+  kname += "_";
+  kname += std::to_string(q.shape(-1));
+
+  // Get the kernel
+  kernel = d.get_kernel(kname);
+  compute_encoder.set_compute_pipeline_state(kernel);
+
+  // Set its arguments
+  compute_encoder.set_input_array(intermediate, 0);
+  compute_encoder.set_input_array(sums, 1);
+  compute_encoder.set_input_array(maxs, 2);
+  compute_encoder.set_output_array(out, 3);
+
+  // Launch
+  group_dims = MTL::Size(1024, 1, 1);
+  grid_dims = MTL::Size(1, B, 1);
+  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
+}
+
 } // namespace
 
 void ScaledDotProductAttention::eval_gpu(
@@ -254,7 +419,6 @@ void ScaledDotProductAttention::eval_gpu(
 
   auto& q_pre = inputs[0];
   auto& k_pre = inputs[1];
-  auto& v_pre = inputs[2];
   auto& o = out;
 
   std::vector<array> copies;
@@ -295,9 +459,7 @@ void ScaledDotProductAttention::eval_gpu(
 
   // We are in vector mode ie single query
   if (q_pre.shape(2) == 1) {
-    const auto& q = copy_unless(is_contiguous, q_pre);
-    const auto& k = copy_unless(is_contiguous_except_seq_len, k_pre);
-    const auto& v = copy_unless(is_contiguous_except_seq_len, v_pre);
+    auto q = copy_unless(is_contiguous, q_pre);
 
     // Donate the query if possible
     if (q.is_donatable()) {
@@ -306,20 +468,55 @@ void ScaledDotProductAttention::eval_gpu(
       o.set_data(allocator::malloc_or_wait(o.nbytes()));
     }
 
-    // We route to the 2 pass fused attention if
-    // - The device is large and the sequence length long
-    // - The sequence length is even longer and we have gqa
-    char devc = d.get_architecture().back();
-    if ((devc == 'd' && k.shape(2) >= 1024) ||
-        (k.shape(1) < q.shape(1) && k.shape(2) >= 4096)) {
-      sdpa_vector_2pass(s, d, q, k, v, o, scale_);
+    if (quantized_) {
+      auto& k_scales_pre = inputs[2];
+      auto& k_biases_pre = inputs[3];
+      auto& v_pre = inputs[4];
+      auto& v_scales_pre = inputs[5];
+      auto& v_biases_pre = inputs[6];
+
+      auto k = copy_unless(is_contiguous_except_seq_len, k_pre);
+      auto k_scales = copy_unless(is_contiguous_except_seq_len, k_scales_pre);
+      auto k_biases = copy_unless(is_contiguous_except_seq_len, k_biases_pre);
+      auto v = copy_unless(is_contiguous_except_seq_len, v_pre);
+      auto v_scales = copy_unless(is_contiguous_except_seq_len, v_scales_pre);
+      auto v_biases = copy_unless(is_contiguous_except_seq_len, v_biases_pre);
+
+      quant_sdpa_vector_2pass(
+          s,
+          d,
+          q,
+          k,
+          k_scales,
+          k_biases,
+          v,
+          v_scales,
+          v_biases,
+          o,
+          scale_,
+          group_size_,
+          bits_);
     } else {
-      sdpa_vector(s, d, q, k, v, o, scale_);
+      auto& k_pre = inputs[1];
+      auto& v_pre = inputs[2];
+
+      const auto& k = copy_unless(is_contiguous_except_seq_len, k_pre);
+      const auto& v = copy_unless(is_contiguous_except_seq_len, v_pre);
+
+      char devc = d.get_architecture().back();
+      if ((devc == 'd' && k.shape(2) >= 1024) ||
+          (k.shape(1) < q.shape(1) && k.shape(2) >= 4096)) {
+        sdpa_vector_2pass(s, d, q, k, v, o, scale_);
+      } else {
+        sdpa_vector(s, d, q, k, v, o, scale_);
+      }
     }
   }
   // Full attention mode
   else {
+    auto& v_pre = inputs[2];
+
     const auto& q = copy_unless(is_matrix_contiguous, q_pre);
     const auto& k = copy_unless(is_matrix_contiguous, k_pre);
     const auto& v = copy_unless(is_matrix_contiguous, v_pre);
diff --git a/mlx/fast.cpp b/mlx/fast.cpp
index 731912d69..37a6ec47b 100644
--- a/mlx/fast.cpp
+++ b/mlx/fast.cpp
@@ -664,7 +664,7 @@ array scaled_dot_product_attention(
         std::move(out_shape),
         final_type,
         std::make_shared<ScaledDotProductAttention>(
-            stream, fallback, scale, false),
+            stream, fallback, scale, /*needs_mask=*/false, /*quantized=*/false),
         {q, k, v});
   }
 
@@ -678,7 +678,130 @@ array scaled_dot_product_attention(
 bool ScaledDotProductAttention::is_equivalent(const Primitive& other) const {
   const ScaledDotProductAttention& a_other =
       static_cast<const ScaledDotProductAttention&>(other);
-  return needs_mask_ == a_other.needs_mask_ && scale_ == a_other.scale_;
+  return needs_mask_ == a_other.needs_mask_ && scale_ == a_other.scale_ &&
+      quantized_ == a_other.quantized_;
+}
+
+array quantized_scaled_dot_product_attention(
+    const array& queries,
+    const array& keys,
+    const array& key_scales,
+    const array& key_biases,
+    const array& values,
+    const array& value_scales,
+    const array& value_biases,
+    const float scale,
+    const std::optional<array>& mask,
+    const int group_size,
+    const int bits,
+    StreamOrDevice s) {
+  int el_per_int = 32 / bits;
+  int out_dim = values.shape(-1) * el_per_int;
+
+  auto n_q_heads = queries.shape(-3);
+  auto n_kv_heads = keys.shape(-3);
+
+  auto out_shape = std::vector<int>(
+      {queries.shape(0), queries.shape(1), queries.shape(2), out_dim});
+  auto stream = to_stream(s);
+  bool needs_mask = mask.has_value();
+  auto fallback =
+      [scale, needs_mask, n_q_heads, n_kv_heads, group_size, bits, &s](
+          const std::vector<array>& inputs) -> std::vector<array> {
+    int n_repeats = n_q_heads / n_kv_heads;
+
+    auto q = multiply(array(scale, inputs[0].dtype()), inputs[0], s);
+
+    auto k = inputs[1];
+    auto k_scales = inputs[2];
+    auto k_biases = inputs[3];
+    auto v = inputs[4];
+    auto v_scales = inputs[5];
+    auto v_biases = inputs[6];
+
+    int B = q.shape(0);
+    int L = q.shape(2);
+
+    if (n_repeats > 1) {
+      q = reshape(q, {B, n_kv_heads, n_repeats, L, -1}, s);
+      k = expand_dims(k, 2, s);
+      k_scales = expand_dims(k_scales, 2, s);
+      k_biases = expand_dims(k_biases, 2, s);
+      v = expand_dims(v, 2, s);
+      v_scales = expand_dims(v_scales, 2, s);
+      v_biases = expand_dims(v_biases, 2, s);
+    }
+
+    array scores = quantized_matmul(
+        q,
+        k,
+        k_scales,
+        k_biases,
+        /*transpose=*/true,
+        /*group_size=*/group_size,
+        /*bits=*/bits,
+        s);
+    if (needs_mask) {
+      scores = add(scores, inputs[7], s);
+    }
+    scores = softmax(scores, std::vector<int>{-1}, true, s);
+    array out = quantized_matmul(
+        scores,
+        v,
+        v_scales,
+        v_biases,
+        /*transpose=*/false,
+        /*group_size=*/group_size,
+        /*bits=*/bits,
+        s);
+    if (n_repeats > 1) {
+      out = reshape(out, {B, n_q_heads, L, -1}, s);
+    }
+    return std::vector<array>{out};
+  };
+
+  int L = queries.shape(2);
+  if (L > 1) {
+    if (needs_mask) {
+      return fallback(
+          {queries,
+           keys,
+           key_scales,
+           key_biases,
+           values,
+           value_scales,
+           value_biases,
+           mask.value()})[0];
+    } else {
+      return fallback(
+          {queries,
+           keys,
+           key_scales,
+           key_biases,
+           values,
+           value_scales,
+           value_biases})[0];
+    }
+  } else {
+    return array(
+        std::move(out_shape),
+        queries.dtype(),
+        std::make_shared<ScaledDotProductAttention>(
+            stream,
+            fallback,
+            scale,
+            /*needs_mask=*/false,
+            /*quantized=*/true,
+            group_size,
+            bits),
+        {queries,
+         keys,
+         key_scales,
+         key_biases,
+         values,
+         value_scales,
+         value_biases});
+  }
 }
 
 array pack_and_quantize(
diff --git a/mlx/fast.h b/mlx/fast.h
index ddc3512b5..dc47e6c46 100644
--- a/mlx/fast.h
+++ b/mlx/fast.h
@@ -41,6 +41,21 @@ array scaled_dot_product_attention(
     const std::optional<int> memory_efficient_threshold = std::nullopt,
     StreamOrDevice s = {});
 
+/** Computes: `O = softmax(Q @ K.T) @ V` where K and V are quantized. **/
+array quantized_scaled_dot_product_attention(
+    const array& queries,
+    const array& keys,
+    const array& key_scales,
+    const array& key_biases,
+    const array& values,
+    const array& value_scales,
+    const array& value_biases,
+    const float scale,
+    const std::optional<array>& mask = std::nullopt,
+    const int group_size = 64,
+    const int bits = 4,
+    StreamOrDevice s = {});
+
 std::tuple<array, array, array> affine_quantize(
     const array& w,
     int group_size = 64,
diff --git a/mlx/fast_primitives.h b/mlx/fast_primitives.h
index 30db282ff..830da18b4 100644
--- a/mlx/fast_primitives.h
+++ b/mlx/fast_primitives.h
@@ -190,8 +190,16 @@ class ScaledDotProductAttention : public Custom {
       Stream stream,
       std::function<std::vector<array>(std::vector<array>)> fallback,
       const float scale,
-      const bool needs_mask)
-      : Custom(stream, fallback), scale_(scale), needs_mask_(needs_mask) {}
+      const bool needs_mask,
+      const bool quantized,
+      const int group_size = 64,
+      const int bits = 4)
+      : Custom(stream, fallback),
+        scale_(scale),
+        needs_mask_(needs_mask),
+        quantized_(quantized),
+        group_size_(group_size),
+        bits_(bits) {}
 
   void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
       override {
@@ -212,6 +220,9 @@ class ScaledDotProductAttention : public Custom {
   std::function<std::vector<array>(std::vector<array>)> fallback_;
   float scale_;
   bool needs_mask_;
+  bool quantized_;
+  int group_size_;
+  int bits_;
 };
 
 class AffineQuantize : public Custom {
diff --git a/python/src/fast.cpp b/python/src/fast.cpp
index cbc8b934d..188a93b22 100644
--- a/python/src/fast.cpp
+++ b/python/src/fast.cpp
@@ -161,6 +161,45 @@ void init_fast(nb::module_& parent_module) {
             array: The output array.
       )pbdoc");
 
+  m.def(
+      "quantized_scaled_dot_product_attention",
+      &fast::quantized_scaled_dot_product_attention,
+      "q"_a,
+      "k"_a,
+      "k_scales"_a,
+      "k_biases"_a,
+      "v"_a,
+      "v_scales"_a,
+      "v_biases"_a,
+      nb::kw_only(),
+      "scale"_a,
+      "mask"_a = nb::none(),
+      "group_size"_a = 64,
+      "bits"_a = 4,
+      "stream"_a = nb::none(),
+      nb::sig(
+          "def quantized_scaled_dot_product_attention(q: array, k: array, k_scales: array, k_biases: array, v: array, v_scales: array, v_biases: array, *, scale: float, mask: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
+      R"pbdoc(
+        A fast implementation of multi-head attention where the keys and values are quantized.
+
+        See :func:`scaled_dot_product_attention` for more details.
+
+        Args:
+            q (array): Input query array.
+            k (array): Input keys array.
+            k_scales (array): Scales for the quantized keys array.
+            k_biases (array): Biases for the quantized keys array.
+            v (array): Input values array.
+            v_scales (array): Scales for the quantized values array.
+            v_biases (array): Biases for the quantized values array.
+            scale (float): Scale for queries (typically ``1.0 / sqrt(q.shape(-1))``).
+            mask (array, optional): An additive mask to apply to the query-key scores.
+            group_size (int): The group size used in the KV quantization.
+            bits (int): The bits used in the KV quantization.
+
+        Returns:
+            array: The output array.
+      )pbdoc");
+
   m.def(
       "metal_kernel",
       [](const std::string& name,