Skip to content

Commit

Permalink
Sync upstream mlx sdpa vector kernels with mask (#66)
Browse files Browse the repository at this point in the history
* Sync upstream mlx sdpa vector kernels with mask

* Format
  • Loading branch information
EricLBuehler authored Jan 18, 2025
1 parent 3388a3c commit b33d200
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 86 deletions.
15 changes: 13 additions & 2 deletions candle-metal-kernels/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2013,7 +2013,12 @@ pub fn call_sdpa_vector(
alpha
};

let pipeline = kernels.load_pipeline(device, Source::Sdpa, &name)?;
let constants = Some(ConstantValues::new(vec![(
20,
Value::Bool(/* sdpa_vector_has_mask */ false),
)]));

let pipeline = kernels.load_pipeline_with_constants(device, Source::Sdpa, &name, constants)?;
let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline);
Expand Down Expand Up @@ -2125,7 +2130,13 @@ pub fn call_sdpa_vector_2pass(
alpha
};

let pipeline = kernels.load_pipeline(device, Source::Sdpa, &name_pass1)?;
let constants = Some(ConstantValues::new(vec![(
20,
Value::Bool(/* sdpa_vector_has_mask */ false),
)]));

let pipeline =
kernels.load_pipeline_with_constants(device, Source::Sdpa, &name_pass1, constants)?;
let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline);
Expand Down
161 changes: 77 additions & 84 deletions candle-metal-kernels/src/scaled_dot_product_attention.metal
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,8 @@ struct MLXScaledDotProductAttentionParams {

// ============ "mlx/backend/metal/kernels/scaled_dot_product_attention_params.sdpa_vector"

constant bool sdpa_vector_has_mask [[function_constant(20)]];

template <typename T, int D>
[[kernel]] void sdpa_vector(
const device T* queries [[buffer(0)]],
Expand All @@ -311,14 +313,16 @@ template <typename T, int D>
const constant size_t& v_stride,
const constant float& scale,
const constant float& softcapping,
const device bool* mask [[function_constant(sdpa_vector_has_mask)]],
const constant int& mask_seq_stride [[function_constant(sdpa_vector_has_mask)]],
const constant int& mask_head_stride [[function_constant(sdpa_vector_has_mask)]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
constexpr int BN = 32;
constexpr int BD = 32;
constexpr int elem_per_thread = D / BD;

const int stride = BN * D;
constexpr int stride = BN * D;

typedef float U;

Expand All @@ -336,6 +340,9 @@ template <typename T, int D>
queries += head_idx * D + simd_lid * elem_per_thread;
keys += kv_head_idx * k_stride + simd_gid * D + simd_lid * elem_per_thread;
values += kv_head_idx * v_stride + simd_gid * D + simd_lid * elem_per_thread;
if (sdpa_vector_has_mask) {
mask += head_idx * mask_head_stride + simd_gid * mask_seq_stride;
}
out += head_idx * D + simd_gid * elem_per_thread;

// Read the query and 0 the output accumulator
Expand All @@ -351,38 +358,43 @@ template <typename T, int D>

// For each key
for (int i = simd_gid; i < N; i += BN) {
// Read the key
for (int i = 0; i < elem_per_thread; i++) {
k[i] = keys[i];
}
if (!sdpa_vector_has_mask || mask[0]) {
// Read the key
for (int j = 0; j < elem_per_thread; j++) {
k[j] = keys[j];
}

// Compute the i-th score
U score = 0;
for (int i = 0; i < elem_per_thread; i++) {
score += q[i] * k[i];
}
score = simd_sum(score);
if (softcapping != 1.) {
score = precise::tanh(score);
score = score * softcapping;
}
// Compute the i-th score
U score = 0;
for (int j = 0; j < elem_per_thread; j++) {
score += q[j] * k[j];
}
score = simd_sum(score);
if (softcapping != 1.) {
score = precise::tanh(score);
score = score * softcapping;
}

// Update the accumulators
U new_max = max(max_score, score);
U factor = fast::exp(max_score - new_max);
U exp_score = fast::exp(score - new_max);
// Update the accumulators
U new_max = max(max_score, score);
U factor = fast::exp(max_score - new_max);
U exp_score = fast::exp(score - new_max);

max_score = new_max;
sum_exp_score = sum_exp_score * factor + exp_score;
max_score = new_max;
sum_exp_score = sum_exp_score * factor + exp_score;

// Update the output accumulator
for (int i = 0; i < elem_per_thread; i++) {
o[i] = o[i] * factor + exp_score * values[i];
// Update the output accumulator
for (int j = 0; j < elem_per_thread; j++) {
o[j] = o[j] * factor + exp_score * values[j];
}
}

// Move the pointers to the next kv
keys += stride;
values += stride;
if (sdpa_vector_has_mask) {
mask += BN * mask_seq_stride;
}
}

// Each thread has a partial part of the output so we need to combine them.
Expand Down Expand Up @@ -428,6 +440,9 @@ template <typename T, int D>
const constant size_t& v_stride,
const constant float& scale,
const constant float& softcapping,
const device bool* mask [[function_constant(sdpa_vector_has_mask)]],
const constant int& mask_seq_stride [[function_constant(sdpa_vector_has_mask)]],
const constant int& mask_head_stride [[function_constant(sdpa_vector_has_mask)]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
Expand Down Expand Up @@ -457,6 +472,10 @@ template <typename T, int D>
values += kv_head_idx * v_stride + (block_idx * BN + simd_gid) * D +
simd_lid * elem_per_thread;
out += head_idx * blocks * D + block_idx * D + simd_lid * elem_per_thread;
if (sdpa_vector_has_mask) {
mask += head_idx * mask_head_stride +
(block_idx * BN + simd_gid) * mask_seq_stride;
}
sums += head_idx * blocks + block_idx;
maxs += head_idx * blocks + block_idx;

Expand All @@ -473,75 +492,43 @@ template <typename T, int D>

// For each key
for (int i = block_idx * BN + simd_gid; i < N; i += blocks * BN) {
// Read the key
for (int i = 0; i < elem_per_thread; i++) {
k[i] = keys[i];
}
if (!sdpa_vector_has_mask || mask[0]) {
// Read the key
for (int i = 0; i < elem_per_thread; i++) {
k[i] = keys[i];
}

// Compute the i-th score
U score = 0;
for (int i = 0; i < elem_per_thread; i++) {
score += q[i] * k[i];
}
score = simd_sum(score);
if (softcapping != 1.) {
score = precise::tanh(score);
score = score * softcapping;
}
// Compute the i-th score
U score = 0;
for (int i = 0; i < elem_per_thread; i++) {
score += q[i] * k[i];
}
score = simd_sum(score);
if (softcapping != 1.) {
score = precise::tanh(score);
score = score * softcapping;
}

// Update the accumulators
U new_max = max(max_score, score);
U factor = fast::exp(max_score - new_max);
U exp_score = fast::exp(score - new_max);
// Update the accumulators
U new_max = max(max_score, score);
U factor = fast::exp(max_score - new_max);
U exp_score = fast::exp(score - new_max);

max_score = new_max;
sum_exp_score = sum_exp_score * factor + exp_score;
max_score = new_max;
sum_exp_score = sum_exp_score * factor + exp_score;

// Update the output accumulator
for (int i = 0; i < elem_per_thread; i++) {
o[i] = o[i] * factor + exp_score * values[i];
// Update the output accumulator
for (int i = 0; i < elem_per_thread; i++) {
o[i] = o[i] * factor + exp_score * values[i];
}
}

// Move the pointers to the next kv
keys += blocks * stride;
values += blocks * stride;
}

// Each thread has a partial part of the output so we need to combine them.

// First let's communicate the max and sum_exp
if (simd_lid == 0) {
max_scores[simd_gid] = max_score;
sum_exp_scores[simd_gid] = sum_exp_score;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
max_score = (simd_lid < BN) ? max_scores[simd_lid] : -1e9;
U new_max = simd_max(max_score);
U factor = fast::exp(max_score - new_max);
sum_exp_score = (simd_lid < BN) ? sum_exp_scores[simd_lid] : 0;
sum_exp_score = simd_sum(sum_exp_score * factor);

// Write the sum and new max
if (simd_gid == 0) {
sums[0] = sum_exp_score;
maxs[0] = new_max;
}

// Now we need to aggregate all the outputs
for (int i = 0; i < elem_per_thread; i++) {
outputs[simd_lid * BN + simd_gid] =
o[i] * fast::exp(max_scores[simd_gid] - new_max);
threadgroup_barrier(mem_flags::mem_threadgroup);

// And write the output
if (simd_gid == 0) {
U output = outputs[simd_lid * BN];
for (int j = 1; j < BN; j++) {
output += outputs[simd_lid * BN + j];
}
out[i] = static_cast<T>(output);
if (sdpa_vector_has_mask) {
mask += BN * blocks * mask_seq_stride;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
}

Expand Down Expand Up @@ -1669,6 +1656,9 @@ instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 256, 2, 2);
const constant size_t& v_stride, \
const constant float& scale, \
const constant float& softcapping, \
const device bool* mask [[function_constant(sdpa_vector_has_mask)]],, \
const constant int& mask_seq_stride [[function_constant(sdpa_vector_has_mask)]], \
const constant int& mask_head_stride [[function_constant(sdpa_vector_has_mask)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]); \
Expand All @@ -1686,6 +1676,9 @@ instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 256, 2, 2);
const constant size_t& v_stride, \
const constant float& scale, \
const constant float& softcapping, \
const device bool* mask [[function_constant(sdpa_vector_has_mask)]],, \
const constant int& mask_seq_stride [[function_constant(sdpa_vector_has_mask)]], \
const constant int& mask_head_stride [[function_constant(sdpa_vector_has_mask)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]); \
Expand Down

0 comments on commit b33d200

Please sign in to comment.