flash_attention: support also cross attention. #8427

Merged
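This PR extends the flash_attention custom-kernel wrapper so that the query and the key/value tensors may have different sequence lengths (cross attention), including under SPMD sharding: the single saved full shape is split into a query shape and a key/value shape so the backward pass restores each gradient to the right global shape. A minimal usage sketch, inferred from the new tests below (the positional argument order and the partition_spec keyword follow the test code; the shapes are illustrative, and SPMD mode is assumed to be enabled via xr.use_spmd(), which the test suite does elsewhere):

import torch
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr
from torch_xla.experimental.custom_kernel import flash_attention

xr.use_spmd()
xs.set_global_mesh(xs.get_1d_mesh("data"))

# Cross attention: the query length (1024) differs from the key/value length (128).
q = torch.randn(4, 2, 1024, 8).to("xla")  # [batch, num_heads, q_seq_len, head_dim]
k = torch.randn(4, 2, 128, 8).to("xla")   # [batch, num_heads, kv_seq_len, head_dim]
v = torch.randn(4, 2, 128, 8).to("xla")

# Segment ids: a query position attends only to key/value positions with the same id.
q_segment_ids = torch.ones(4, 1024, dtype=torch.float32).to("xla")
kv_segment_ids = torch.ones(4, 128, dtype=torch.float32).to("xla")

o = flash_attention(
    q,
    k,
    v,
    False,  # causal masking disabled for cross attention
    q_segment_ids,
    kv_segment_ids,
    partition_spec=("data", None, None, None))  # shard only the batch ("data") axis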
test/test_pallas_spmd.py (114 additions, 0 deletions)
@@ -226,6 +226,120 @@ def test_flash_attention_backward_segment_ids_spmd(self):
      self.assertTrue(torch.allclose(i[0].grad.cpu(), i[1].cpu(), atol=1e-05))
    jax.config.update("jax_default_matmul_precision", "default")

  @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
                   "This test only works on TPUv3+.")
  def test_cross_flash_attention_wrapper_segment_ids_spmd(self):
    from torch_xla.experimental.custom_kernel import flash_attention
    from jax.experimental.pallas.ops.tpu.flash_attention import flash_attention as jax_flash_attention, SegmentIds
    xs.set_global_mesh(xs.get_1d_mesh("data"))

    q = torch.randn(3, 2, 1024, 4)
    k = torch.randn(3, 2, 128, 4)
    v = torch.randn(3, 2, 128, 4)
    zeros = torch.zeros(3, 32)
    kv_segment_ids = torch.cat([zeros, zeros + 1, zeros + 2, zeros + 3], dim=1)
    q_segment_ids = torch.ones(3, q.shape[2], dtype=torch.float32)
    # only shard data dimension
    o = flash_attention(
        q.to("xla"),
        k.to("xla"),
        v.to("xla"),
        False,
        q_segment_ids.to("xla"),
        kv_segment_ids.to("xla"),
        partition_spec=("data", None, None, None))
    self.assertEqual(
        torch_xla._XLAC._get_xla_sharding_spec(o),
        f"{{devices=[{xr.global_runtime_device_count()},1,1,1]0,1,2,3}}")

    jax_q = jnp.array(q.numpy(), dtype=jnp.float32)
    jax_k = jnp.array(k.numpy(), dtype=jnp.float32)
    jax_v = jnp.array(v.numpy(), dtype=jnp.float32)
    jax_q_segment_ids = jnp.array(q_segment_ids.numpy(), dtype=jnp.float32)
    jax_kv_segment_ids = jnp.array(kv_segment_ids.numpy(), dtype=jnp.float32)
    expected_o = torch.from_numpy(
        np.array(
            jax_flash_attention(
                jax_q,
                jax_k,
                jax_v,
                segment_ids=SegmentIds(jax_q_segment_ids, jax_kv_segment_ids),
            )))

    self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu(), atol=1e-05))
    jax.config.update('jax_default_matmul_precision', "default")

  @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
                   "This test only works on TPUv3+.")
  def test_cross_flash_attention_backward_segment_ids_spmd(self):
    jax.config.update("jax_default_matmul_precision", "highest")
    from torch_xla.experimental.custom_kernel import flash_attention
    n_devices = xr.global_runtime_device_count()
    xs.set_global_mesh(xs.get_1d_mesh("data"))

    torch.manual_seed(42)
    q = torch.randn(4, 2, 1024, 8, requires_grad=True).to("xla")
    k = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
    v = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
    zeros = torch.zeros(4, 32).to("xla")
    kv_segment_ids = torch.cat([zeros, zeros + 1, zeros + 2, zeros + 3], dim=1)
    q_segment_ids = torch.ones(4, q.shape[2], dtype=torch.float32).to("xla")
    q.retain_grad()
    k.retain_grad()
    v.retain_grad()

    o = flash_attention(
        q,
        k,
        v,
        False,
        q_segment_ids,
        kv_segment_ids,
        partition_spec=("data", None, None, None))
    loss = o.sum()
    loss.backward()
    q_grad = q.grad
    k_grad = k.grad
    v_grad = v.grad
    self.assertEqual(
        torch_xla._XLAC._get_xla_sharding_spec(o),
        f"{{devices=[{n_devices},1,1,1]0,1,2,3}}")
    self.assertEqual(
        torch_xla._XLAC._get_xla_sharding_spec(q_grad),
        f"{{devices=[{n_devices},1,1,1]0,1,2,3}}")
    self.assertEqual(
        torch_xla._XLAC._get_xla_sharding_spec(k_grad),
        f"{{devices=[{n_devices},1,1,1]0,1,2,3}}")
    self.assertEqual(
        torch_xla._XLAC._get_xla_sharding_spec(v_grad),
        f"{{devices=[{n_devices},1,1,1]0,1,2,3}}")
    torch_xla.sync()

    torch.manual_seed(42)
    q = torch.randn(4, 2, 1024, 8, requires_grad=True).to("xla")
    k = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
    v = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
    zeros = torch.zeros(4, 32).to("xla")
    kv_segment_ids = torch.cat([zeros, zeros + 1, zeros + 2, zeros + 3], dim=1)
    q_segment_ids = torch.ones(4, q.shape[2], dtype=torch.float32).to("xla")
    q.retain_grad()
    k.retain_grad()
    v.retain_grad()

    o = self._attention(
        q,
        k,
        v,
        attn_mask=self._make_attention_mask_from_segment_ids(
            q_segment_ids, kv_segment_ids))
    loss = o.sum()
    loss.backward()
    xm.mark_step()

    for i in [(q, q_grad), (k, k_grad), (v, v_grad)]:
      self.assertTrue(torch.allclose(i[0].grad.cpu(), i[1].cpu(), atol=1e-05))
    jax.config.update("jax_default_matmul_precision", "default")


if __name__ == '__main__':
  logging.getLogger().setLevel(logging.INFO)
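The backward test above checks the kernel against a plain attention reference (self._attention) driven by a dense mask built from the same segment ids. That helper is defined elsewhere in the test file and is not part of this diff; as a rough sketch of the semantics it encodes (a query position may attend to a key/value position only when their segment ids are equal), with the function name and the mask convention being assumptions:

import torch

def make_attention_mask_from_segment_ids(q_segment_ids, kv_segment_ids):
  # q_segment_ids: [batch, q_seq_len]; kv_segment_ids: [batch, kv_seq_len].
  # True marks pairs whose segment ids differ, i.e. positions to mask out;
  # the exact convention (mask vs. keep) depends on the reference _attention.
  different = q_segment_ids[:, :, None] != kv_segment_ids[:, None, :]
  # Add a singleton head dimension: [batch, 1, q_seq_len, kv_seq_len].
  return different.unsqueeze(1)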
torch_xla/experimental/custom_kernel.py (15 additions, 10 deletions)
@@ -237,7 +237,8 @@ def forward(ctx, q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab,
     ctx.sm_scale = sm_scale
     ctx.partition_spec = partition_spec
     ctx.mesh = mesh
-    ctx.full_shape = None
+    ctx.q_full_shape = None
+    ctx.kv_full_shape = None
     save_residuals = q.requires_grad or k.requires_grad or v.requires_grad

     # SPMD integration.
@@ -247,7 +248,8 @@ def forward(ctx, q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab,
     full_v = v
     full_ab = ab
     if partition_spec is not None:
-      ctx.full_shape = q.shape
+      ctx.q_full_shape = q.shape
+      ctx.kv_full_shape = k.shape
       q = xs.enable_manual_sharding(q, partition_spec, mesh=mesh).global_tensor
       k = xs.enable_manual_sharding(k, partition_spec, mesh=mesh).global_tensor
       v = xs.enable_manual_sharding(v, partition_spec, mesh=mesh).global_tensor
@@ -313,19 +315,21 @@ def forward(ctx, q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab,
       # SPMD integration
       if partition_spec is not None:
         o = xs.disable_manual_sharding(
-            o, partition_spec, ctx.full_shape, mesh=mesh).global_tensor
+            o, partition_spec, ctx.q_full_shape, mesh=mesh).global_tensor
       return o
     o, *aux = o
     l, m = (v[..., 0] for v in aux[-2:])

     # SPMD integration
     if partition_spec is not None:
       o = xs.disable_manual_sharding(
-          o, partition_spec, ctx.full_shape, mesh=mesh).global_tensor
+          o, partition_spec, ctx.q_full_shape, mesh=mesh).global_tensor
       l = xs.disable_manual_sharding(
-          l, partition_spec[0:3], ctx.full_shape[0:3], mesh=mesh).global_tensor
+          l, partition_spec[0:3], ctx.q_full_shape[0:3],
+          mesh=mesh).global_tensor
       m = xs.disable_manual_sharding(
-          m, partition_spec[0:3], ctx.full_shape[0:3], mesh=mesh).global_tensor
+          m, partition_spec[0:3], ctx.q_full_shape[0:3],
+          mesh=mesh).global_tensor

     # q_segment_ids and kv_segment_ids are sharded here if partition_spec is provided
     # but it should be OK as the backward will use the same partition_spec
@@ -342,7 +346,8 @@ def backward(ctx, grad_output):
     sm_scale = ctx.sm_scale
     partition_spec = ctx.partition_spec
     mesh = ctx.mesh
-    full_shape = ctx.full_shape
+    q_full_shape = ctx.q_full_shape
+    kv_full_shape = ctx.kv_full_shape
     # this segment_ids only reflects the local shape of segment_ids
     segment_ids = ctx.segment_ids
     grad_q = grad_k = grad_v = grad_ab = None
@@ -467,11 +472,11 @@ def backward(ctx, grad_output):
     # SPMD integration
     if partition_spec is not None:
       grad_q = xs.disable_manual_sharding(
-          grad_q, partition_spec, full_shape, mesh=mesh).global_tensor
+          grad_q, partition_spec, q_full_shape, mesh=mesh).global_tensor
       grad_k = xs.disable_manual_sharding(
-          grad_k, partition_spec, full_shape, mesh=mesh).global_tensor
+          grad_k, partition_spec, kv_full_shape, mesh=mesh).global_tensor
       grad_v = xs.disable_manual_sharding(
-          grad_v, partition_spec, full_shape, mesh=mesh).global_tensor
+          grad_v, partition_spec, kv_full_shape, mesh=mesh).global_tensor

     return grad_q, grad_k, grad_v, None, None, None, None, grad_ab, None, None

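The kernel-side change itself is small: the forward previously stashed a single ctx.full_shape, which implicitly assumed q, k, and v share one global shape. That only holds for self attention; with different query and key/value sequence lengths the backward must restore grad_q to q's global shape and grad_k/grad_v to k's. A shapes-only sketch of the bookkeeping, using the numbers from the new tests (4-device "data" mesh, partition_spec=("data", None, None, None)):

# Global shapes recorded in forward (mirroring ctx.q_full_shape / ctx.kv_full_shape).
q_full_shape = (4, 2, 1024, 8)   # [batch, heads, q_seq_len, head_dim]
kv_full_shape = (4, 2, 128, 8)   # [batch, heads, kv_seq_len, head_dim]

# Under manual sharding each of the 4 devices sees 1/4 of the batch dimension.
local_q_shape = (1, 2, 1024, 8)
local_kv_shape = (1, 2, 128, 8)

# When leaving manual-sharding mode in backward, each gradient is widened back
# to the global shape of the tensor it corresponds to:
#   grad_q -> disable_manual_sharding(..., q_full_shape, ...)
#   grad_k -> disable_manual_sharding(..., kv_full_shape, ...)
#   grad_v -> disable_manual_sharding(..., kv_full_shape, ...)
# Reusing one full_shape for all three (the old behavior) would hand grad_k and
# grad_v the 1024-long query shape, which only works when q_seq_len == kv_seq_len.
assert q_full_shape[2] != kv_full_shape[2]  # the cross-attention case this PR enables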