diff --git a/torch_xla/experimental/custom_kernel.py b/torch_xla/experimental/custom_kernel.py
index fdc5992c3b0..d1b47485f7e 100644
--- a/torch_xla/experimental/custom_kernel.py
+++ b/torch_xla/experimental/custom_kernel.py
@@ -246,6 +246,9 @@ def forward(ctx, q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab,
     full_k = k
     full_v = v
     full_ab = ab
+    _, full_q_segment_ids, full_kv_segment_ids = FlashAttention.prepare_segment_ids(
+        q_segment_ids, kv_segment_ids)
+
     if partition_spec is not None:
       ctx.full_shape = q.shape
       q = xs.enable_manual_sharding(q, partition_spec, mesh=mesh).global_tensor
@@ -254,6 +257,14 @@ def forward(ctx, q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab,
       if ab:
         ab = xs.enable_manual_sharding(
             ab, partition_spec, mesh=mesh).global_tensor
+      if q_segment_ids is not None:
+        q_segment_ids = xs.enable_manual_sharding(
+            q_segment_ids, partition_spec[:q_segment_ids.ndim],
+            mesh=mesh).global_tensor
+      if kv_segment_ids is not None:
+        kv_segment_ids = xs.enable_manual_sharding(
+            kv_segment_ids, partition_spec[:kv_segment_ids.ndim],
+            mesh=mesh).global_tensor
 
     # It computes the shape and type of o, l, m.
     shapes = [q.shape]
@@ -319,8 +330,8 @@ def forward(ctx, q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab,
       m = xs.disable_manual_sharding(
           m, partition_spec[0:3], ctx.full_shape[0:3], mesh=mesh).global_tensor
 
-    ctx.save_for_backward(full_q, full_k, full_v, o, l, m, q_segment_ids,
-                          kv_segment_ids, full_ab)
+    ctx.save_for_backward(full_q, full_k, full_v, o, l, m, full_q_segment_ids,
+                          full_kv_segment_ids, full_ab)
     return o
 
   @staticmethod
@@ -363,6 +374,14 @@ def backward(ctx, grad_output):
       if ab:
         ab = xs.enable_manual_sharding(
             ab, partition_spec, mesh=mesh).global_tensor
+      if q_segment_ids is not None:
+        q_segment_ids = xs.enable_manual_sharding(
+            q_segment_ids, partition_spec[:q_segment_ids.ndim],
+            mesh=mesh).global_tensor
+      if kv_segment_ids is not None:
+        kv_segment_ids = xs.enable_manual_sharding(
+            kv_segment_ids, partition_spec[:kv_segment_ids.ndim],
+            mesh=mesh).global_tensor
 
     if ctx.needs_input_grad[0]:
       payload, _ = trace_pallas(
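
The sketch below (not part of the diff) shows one way the sharded segment ids could be exercised end to end. It assumes the public flash_attention wrapper in torch_xla.experimental.custom_kernel, a TPU SPMD runtime, and an unnamed 4-axis data-parallel mesh; the tensor shapes, mesh layout, and variable names are illustrative only.

# Usage sketch, assuming the flash_attention wrapper and a TPU SPMD runtime.
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr
from torch_xla.experimental.custom_kernel import flash_attention

xr.use_spmd()
device = xm.xla_device()
n_devices = xr.global_runtime_device_count()
# Data-parallel mesh: the batch axis is sharded across all devices,
# the remaining axes are replicated (illustrative layout).
mesh = xs.Mesh(list(range(n_devices)), (n_devices, 1, 1, 1))

q = torch.randn(4, 2, 128, 64, device=device)  # [batch, heads, q_seq, head_dim]
k = torch.randn(4, 2, 128, 64, device=device)
v = torch.randn(4, 2, 128, 64, device=device)
# Segment ids are rank-2 [batch, seq]; with this change they are sharded
# with partition_spec[:2] inside the kernel instead of staying replicated.
q_segment_ids = torch.zeros(4, 128, dtype=torch.float32, device=device)
kv_segment_ids = torch.zeros(4, 128, dtype=torch.float32, device=device)

o = flash_attention(
    q, k, v,
    causal=True,
    q_segment_ids=q_segment_ids,
    kv_segment_ids=kv_segment_ids,
    partition_spec=(0, 1, 2, 3),  # integer indices into the unnamed mesh axes
    mesh=mesh)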