diff --git a/torch_xla/experimental/custom_kernel.py b/torch_xla/experimental/custom_kernel.py index d1b47485f7e..9d8f27183d5 100644 --- a/torch_xla/experimental/custom_kernel.py +++ b/torch_xla/experimental/custom_kernel.py @@ -261,10 +261,10 @@ def forward(ctx, q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab, q_segment_ids = xs.enable_manual_sharding( q_segment_ids, partition_spec[:q_segment_ids.ndim], mesh=mesh).global_tensor - if kv_segment_ids is not None: - kv_segment_ids = xs.enable_manual_sharding( - kv_segment_ids, partition_spec[:kv_segment_ids.ndim], - mesh=mesh).global_tensor + # if kv_segment_ids is not None: + # kv_segment_ids = xs.enable_manual_sharding( + # kv_segment_ids, partition_spec[:kv_segment_ids.ndim], + # mesh=mesh).global_tensor # It computes the shape and type of o, l, m. shapes = [q.shape] @@ -378,10 +378,10 @@ def backward(ctx, grad_output): q_segment_ids = xs.enable_manual_sharding( q_segment_ids, partition_spec[:q_segment_ids.ndim], mesh=mesh).global_tensor - if kv_segment_ids is not None: - kv_segment_ids = xs.enable_manual_sharding( - kv_segment_ids, partition_spec[:kv_segment_ids.ndim], - mesh=mesh).global_tensor + # if kv_segment_ids is not None: + # kv_segment_ids = xs.enable_manual_sharding( + # kv_segment_ids, partition_spec[:kv_segment_ids.ndim], + # mesh=mesh).global_tensor if ctx.needs_input_grad[0]: payload, _ = trace_pallas(