Skip to content

Commit

Permalink
flash_attention: support also cross attention.
Browse files Browse the repository at this point in the history
In case that q and kv have different shapes (cross attention) flash
attention with spmd fails since it does not support it.
  • Loading branch information
dudulightricks committed Dec 1, 2024
1 parent 1c91219 commit 840e3fe
Show file tree
Hide file tree
Showing 2 changed files with 130 additions and 10 deletions.
115 changes: 115 additions & 0 deletions test/test_pallas_spmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,121 @@ def test_flash_attention_backward_segment_ids_spmd(self):
jax.config.update("jax_default_matmul_precision", "default")


@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
"This test only works on TPUv3+.")
def test_cross_flash_attention_wrapper_segment_ids_spmd(self):
from torch_xla.experimental.custom_kernel import flash_attention
from jax.experimental.pallas.ops.tpu.flash_attention import flash_attention as jax_flash_attention, SegmentIds
xs.set_global_mesh(xs.get_1d_mesh("data"))

q = torch.randn(3, 2, 1024, 4)
k = torch.randn(3, 2, 128, 4)
v = torch.randn(3, 2, 128, 4)
zeros = torch.zeros(3, 32)
kv_segment_ids = torch.cat([zeros, zeros + 1, zeros + 2, zeros + 3], dim=1)
q_segment_ids = torch.ones(3, q.shape[2], dtype=torch.float32)
# only shard data dimension
o = flash_attention(
q.to("xla"),
k.to("xla"),
v.to("xla"),
False,
q_segment_ids.to("xla"),
kv_segment_ids.to("xla"),
partition_spec=("data", None, None, None))
self.assertEqual(
torch_xla._XLAC._get_xla_sharding_spec(o),
f"{{devices=[{xr.global_runtime_device_count()},1,1,1]0,1,2,3}}")

jax_q = jnp.array(q.numpy(), dtype=jnp.float32)
jax_k = jnp.array(k.numpy(), dtype=jnp.float32)
jax_v = jnp.array(v.numpy(), dtype=jnp.float32)
jax_q_segment_ids = jnp.array(q_segment_ids.numpy(), dtype=jnp.float32)
jax_kv_segment_ids = jnp.array(kv_segment_ids.numpy(), dtype=jnp.float32)
expected_o = torch.from_numpy(
np.array(
jax_flash_attention(
jax_q,
jax_k,
jax_v,
segment_ids=SegmentIds(jax_q_segment_ids, jax_kv_segment_ids),
)))

self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu(), atol=1e-05))
jax.config.update('jax_default_matmul_precision', "default")

@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
"This test only works on TPUv3+.")
def test_cross_flash_attention_backward_segment_ids_spmd(self):
jax.config.update("jax_default_matmul_precision", "highest")
from torch_xla.experimental.custom_kernel import flash_attention
n_devices = xr.global_runtime_device_count()
xs.set_global_mesh(xs.get_1d_mesh("data"))

torch.manual_seed(42)
q = torch.randn(4, 2, 1024, 8, requires_grad=True).to("xla")
k = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
v = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
zeros = torch.zeros(4, 32).to("xla")
kv_segment_ids = torch.cat([zeros, zeros + 1, zeros + 2, zeros + 3], dim=1)
q_segment_ids = torch.ones(3, q.shape[2], dtype=torch.float32).to("xla")
q.retain_grad()
k.retain_grad()
v.retain_grad()

o = flash_attention(
q,
k,
v,
False,
q_segment_ids,
kv_segment_ids,
partition_spec=("data", None, None, None))
loss = o.sum()
loss.backward()
q_grad = q.grad
k_grad = k.grad
v_grad = v.grad
self.assertEqual(
torch_xla._XLAC._get_xla_sharding_spec(o),
f"{{devices=[{n_devices},1,1,1]0,1,2,3}}")
self.assertEqual(
torch_xla._XLAC._get_xla_sharding_spec(q_grad),
f"{{devices=[{n_devices},1,1,1]0,1,2,3}}")
self.assertEqual(
torch_xla._XLAC._get_xla_sharding_spec(k_grad),
f"{{devices=[{n_devices},1,1,1]0,1,2,3}}")
self.assertEqual(
torch_xla._XLAC._get_xla_sharding_spec(v_grad),
f"{{devices=[{n_devices},1,1,1]0,1,2,3}}")
torch_xla.sync()

torch.manual_seed(42)
q = torch.randn(4, 2, 1024, 8, requires_grad=True).to("xla")
k = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
v = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
zeros = torch.zeros(4, 32).to("xla")
kv_segment_ids = torch.cat([zeros, zeros + 1, zeros + 2, zeros + 3], dim=1)
q_segment_ids = torch.ones(3, q.shape[2], dtype=torch.float32).to("xla")
q.retain_grad()
k.retain_grad()
v.retain_grad()

o = self._attention(
q,
k,
v,
attn_mask=self._make_attention_mask_from_segment_ids(
q_segment_ids, kv_segment_ids))
loss = o.sum()
loss.backward()
xm.mark_step()

for i in [(q, q_grad), (k, k_grad), (v, v_grad)]:
self.assertTrue(torch.allclose(i[0].grad.cpu(), i[1].cpu(), atol=1e-05))
jax.config.update("jax_default_matmul_precision", "default")


if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
torch.set_default_dtype(torch.float32)
Expand Down
25 changes: 15 additions & 10 deletions torch_xla/experimental/custom_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,8 @@ def forward(ctx, q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab,
ctx.sm_scale = sm_scale
ctx.partition_spec = partition_spec
ctx.mesh = mesh
ctx.full_shape = None
ctx.q_full_shape = None
ctx.kv_full_shape = None
save_residuals = q.requires_grad or k.requires_grad or v.requires_grad

# SPMD integration.
Expand All @@ -247,7 +248,8 @@ def forward(ctx, q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab,
full_v = v
full_ab = ab
if partition_spec is not None:
ctx.full_shape = q.shape
ctx.q_full_shape = q.shape
ctx.kv_full_shape = k.shape
q = xs.enable_manual_sharding(q, partition_spec, mesh=mesh).global_tensor
k = xs.enable_manual_sharding(k, partition_spec, mesh=mesh).global_tensor
v = xs.enable_manual_sharding(v, partition_spec, mesh=mesh).global_tensor
Expand Down Expand Up @@ -313,19 +315,21 @@ def forward(ctx, q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab,
# SPMD integration
if partition_spec is not None:
o = xs.disable_manual_sharding(
o, partition_spec, ctx.full_shape, mesh=mesh).global_tensor
o, partition_spec, ctx.q_full_shape, mesh=mesh).global_tensor
return o
o, *aux = o
l, m = (v[..., 0] for v in aux[-2:])

# SPMD integration
if partition_spec is not None:
o = xs.disable_manual_sharding(
o, partition_spec, ctx.full_shape, mesh=mesh).global_tensor
o, partition_spec, ctx.q_full_shape, mesh=mesh).global_tensor
l = xs.disable_manual_sharding(
l, partition_spec[0:3], ctx.full_shape[0:3], mesh=mesh).global_tensor
l, partition_spec[0:3], ctx.q_full_shape[0:3],
mesh=mesh).global_tensor
m = xs.disable_manual_sharding(
m, partition_spec[0:3], ctx.full_shape[0:3], mesh=mesh).global_tensor
m, partition_spec[0:3], ctx.q_full_shape[0:3],
mesh=mesh).global_tensor

# q_segment_ids and kv_segment_ids are sharded here if partition_spec is provided
# but it should be OK as the backward will use the same partition_spec
Expand All @@ -342,7 +346,8 @@ def backward(ctx, grad_output):
sm_scale = ctx.sm_scale
partition_spec = ctx.partition_spec
mesh = ctx.mesh
full_shape = ctx.full_shape
q_full_shape = ctx.q_full_shape
kv_full_shape = ctx.kv_full_shape
# this segment_ids only reflects the local shape of segment_ids
segment_ids = ctx.segment_ids
grad_q = grad_k = grad_v = grad_ab = None
Expand Down Expand Up @@ -467,11 +472,11 @@ def backward(ctx, grad_output):
# SPMD integration
if partition_spec is not None:
grad_q = xs.disable_manual_sharding(
grad_q, partition_spec, full_shape, mesh=mesh).global_tensor
grad_q, partition_spec, q_full_shape, mesh=mesh).global_tensor
grad_k = xs.disable_manual_sharding(
grad_k, partition_spec, full_shape, mesh=mesh).global_tensor
grad_k, partition_spec, kv_full_shape, mesh=mesh).global_tensor
grad_v = xs.disable_manual_sharding(
grad_v, partition_spec, full_shape, mesh=mesh).global_tensor
grad_v, partition_spec, kv_full_shape, mesh=mesh).global_tensor

return grad_q, grad_k, grad_v, None, None, None, None, grad_ab, None, None

Expand Down

0 comments on commit 840e3fe

Please sign in to comment.