custom_kernel: fix shape mismatch by sharding segment_ids in flash attn.
When sharding support was added to this module, segment_ids weren't
taken into account, which caused a shape-mismatch failure when using
them in sharded flash attention.
dudulightricks committed Nov 20, 2024
1 parent 2ec2264 commit 0c79164
Showing 1 changed file with 21 additions and 2 deletions.
23 changes: 21 additions & 2 deletions torch_xla/experimental/custom_kernel.py
@@ -246,6 +246,9 @@ def forward(ctx, q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab,
     full_k = k
     full_v = v
     full_ab = ab
+    _, full_q_segment_ids, full_kv_segment_ids = FlashAttention.prepare_segment_ids(
+        q_segment_ids, kv_segment_ids)
+
     if partition_spec is not None:
       ctx.full_shape = q.shape
       q = xs.enable_manual_sharding(q, partition_spec, mesh=mesh).global_tensor
@@ -254,6 +257,14 @@ def forward(ctx, q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab,
       if ab:
         ab = xs.enable_manual_sharding(
             ab, partition_spec, mesh=mesh).global_tensor
+      if q_segment_ids is not None:
+        q_segment_ids = xs.enable_manual_sharding(
+            q_segment_ids, partition_spec[:q_segment_ids.ndim],
+            mesh=mesh).global_tensor
+      if kv_segment_ids is not None:
+        kv_segment_ids = xs.enable_manual_sharding(
+            kv_segment_ids, partition_spec[:kv_segment_ids.ndim],
+            mesh=mesh).global_tensor
 
     # It computes the shape and type of o, l, m.
     shapes = [q.shape]
@@ -319,8 +330,8 @@ def forward(ctx, q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab,
       m = xs.disable_manual_sharding(
           m, partition_spec[0:3], ctx.full_shape[0:3], mesh=mesh).global_tensor
 
-    ctx.save_for_backward(full_q, full_k, full_v, o, l, m, q_segment_ids,
-                          kv_segment_ids, full_ab)
+    ctx.save_for_backward(full_q, full_k, full_v, o, l, m, full_q_segment_ids,
+                          full_kv_segment_ids, full_ab)
     return o
 
   @staticmethod
@@ -363,6 +374,14 @@ def backward(ctx, grad_output):
       if ab:
         ab = xs.enable_manual_sharding(
             ab, partition_spec, mesh=mesh).global_tensor
+      if q_segment_ids is not None:
+        q_segment_ids = xs.enable_manual_sharding(
+            q_segment_ids, partition_spec[:q_segment_ids.ndim],
+            mesh=mesh).global_tensor
+      if kv_segment_ids is not None:
+        kv_segment_ids = xs.enable_manual_sharding(
+            kv_segment_ids, partition_spec[:kv_segment_ids.ndim],
+            mesh=mesh).global_tensor
 
     if ctx.needs_input_grad[0]:
       payload, _ = trace_pallas(
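For context, here is a minimal standalone sketch of the rank-aware sharding idea the diff applies to segment_ids. It does not call the torch_xla API; shard_shape, the mesh sizes, and all shapes are illustrative assumptions. The point it demonstrates is why the diff slices the spec with partition_spec[:q_segment_ids.ndim]: q/k/v are rank-4 while segment_ids are rank-2, so reusing the full 4-entry spec on segment_ids is exactly the shape/rank mismatch the commit fixes.

# Illustrative sketch, not torch_xla code. In the real module, sharding is done
# via xs.enable_manual_sharding(tensor, partition_spec, mesh=mesh).
def shard_shape(shape, partition_spec, mesh_axis_sizes):
  """Hypothetical helper: per-device shape after sharding `shape` by `partition_spec`."""
  if len(shape) != len(partition_spec):
    raise ValueError(
        f"rank mismatch: tensor rank {len(shape)} vs spec rank {len(partition_spec)}")
  return tuple(
      dim if axis is None else dim // mesh_axis_sizes[axis]
      for dim, axis in zip(shape, partition_spec))

mesh_axis_sizes = {"data": 4, "model": 2}    # assumed 8-device mesh
partition_spec = ("data", None, None, None)  # shard the batch dimension only

q_shape = (8, 16, 1024, 128)                 # (batch, heads, seq, head_dim), rank 4
seg_shape = (8, 1024)                        # (batch, seq), rank 2

print(shard_shape(q_shape, partition_spec, mesh_axis_sizes))  # (2, 16, 1024, 128)
# shard_shape(seg_shape, partition_spec, ...) raises: tensor rank 2 vs spec rank 4.
# The fix mirrors partition_spec[:q_segment_ids.ndim] from the diff:
print(shard_shape(seg_shape, partition_spec[:len(seg_shape)], mesh_axis_sizes))  # (2, 1024)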
