Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
dudulightricks committed Nov 21, 2024
1 parent 0608900 commit 45c5201
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions torch_xla/experimental/custom_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,10 +261,10 @@ def forward(ctx, q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab,
q_segment_ids = xs.enable_manual_sharding(
q_segment_ids, partition_spec[:q_segment_ids.ndim],
mesh=mesh).global_tensor
if kv_segment_ids is not None:
kv_segment_ids = xs.enable_manual_sharding(
kv_segment_ids, partition_spec[:kv_segment_ids.ndim],
mesh=mesh).global_tensor
# if kv_segment_ids is not None:
# kv_segment_ids = xs.enable_manual_sharding(
# kv_segment_ids, partition_spec[:kv_segment_ids.ndim],
# mesh=mesh).global_tensor

# It computes the shape and type of o, l, m.
shapes = [q.shape]
Expand Down Expand Up @@ -378,10 +378,10 @@ def backward(ctx, grad_output):
q_segment_ids = xs.enable_manual_sharding(
q_segment_ids, partition_spec[:q_segment_ids.ndim],
mesh=mesh).global_tensor
if kv_segment_ids is not None:
kv_segment_ids = xs.enable_manual_sharding(
kv_segment_ids, partition_spec[:kv_segment_ids.ndim],
mesh=mesh).global_tensor
# if kv_segment_ids is not None:
# kv_segment_ids = xs.enable_manual_sharding(
# kv_segment_ids, partition_spec[:kv_segment_ids.ndim],
# mesh=mesh).global_tensor

if ctx.needs_input_grad[0]:
payload, _ = trace_pallas(
Expand Down

0 comments on commit 45c5201

Please sign in to comment.