diff --git a/test/test_pallas_spmd.py b/test/test_pallas_spmd.py index 713def2b8b1..e4ae35387bd 100644 --- a/test/test_pallas_spmd.py +++ b/test/test_pallas_spmd.py @@ -227,6 +227,121 @@ def test_flash_attention_backward_segment_ids_spmd(self): jax.config.update("jax_default_matmul_precision", "default") + @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3, + "This test only works on TPUv3+.") + def test_cross_flash_attention_wrapper_segment_ids_spmd(self): + from torch_xla.experimental.custom_kernel import flash_attention + from jax.experimental.pallas.ops.tpu.flash_attention import flash_attention as jax_flash_attention, SegmentIds + xs.set_global_mesh(xs.get_1d_mesh("data")) + + q = torch.randn(3, 2, 1024, 4) + k = torch.randn(3, 2, 128, 4) + v = torch.randn(3, 2, 128, 4) + zeros = torch.zeros(3, 32) + kv_segment_ids = torch.cat([zeros, zeros + 1, zeros + 2, zeros + 3], dim=1) + q_segment_ids = torch.ones(3, q.shape[2], dtype=torch.float32) + # only shard data dimension + o = flash_attention( + q.to("xla"), + k.to("xla"), + v.to("xla"), + False, + q_segment_ids.to("xla"), + kv_segment_ids.to("xla"), + partition_spec=("data", None, None, None)) + self.assertEqual( + torch_xla._XLAC._get_xla_sharding_spec(o), + f"{{devices=[{xr.global_runtime_device_count()},1,1,1]0,1,2,3}}") + + jax_q = jnp.array(q.numpy(), dtype=jnp.float32) + jax_k = jnp.array(k.numpy(), dtype=jnp.float32) + jax_v = jnp.array(v.numpy(), dtype=jnp.float32) + jax_q_segment_ids = jnp.array(q_segment_ids.numpy(), dtype=jnp.float32) + jax_kv_segment_ids = jnp.array(kv_segment_ids.numpy(), dtype=jnp.float32) + expected_o = torch.from_numpy( + np.array( + jax_flash_attention( + jax_q, + jax_k, + jax_v, + segment_ids=SegmentIds(jax_q_segment_ids, jax_kv_segment_ids), + ))) + + self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu(), atol=1e-05)) + jax.config.update('jax_default_matmul_precision', "default") + + @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3, + "This test only works on TPUv3+.") + def test_cross_flash_attention_backward_segment_ids_spmd(self): + jax.config.update("jax_default_matmul_precision", "highest") + from torch_xla.experimental.custom_kernel import flash_attention + n_devices = xr.global_runtime_device_count() + xs.set_global_mesh(xs.get_1d_mesh("data")) + + torch.manual_seed(42) + q = torch.randn(4, 2, 1024, 8, requires_grad=True).to("xla") + k = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla") + v = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla") + zeros = torch.zeros(4, 32).to("xla") + kv_segment_ids = torch.cat([zeros, zeros + 1, zeros + 2, zeros + 3], dim=1) + q_segment_ids = torch.ones(4, q.shape[2], dtype=torch.float32).to("xla") + q.retain_grad() + k.retain_grad() + v.retain_grad() + + o = flash_attention( + q, + k, + v, + False, + q_segment_ids, + kv_segment_ids, + partition_spec=("data", None, None, None)) + loss = o.sum() + loss.backward() + q_grad = q.grad + k_grad = k.grad + v_grad = v.grad + self.assertEqual( + torch_xla._XLAC._get_xla_sharding_spec(o), + f"{{devices=[{n_devices},1,1,1]0,1,2,3}}") + self.assertEqual( + torch_xla._XLAC._get_xla_sharding_spec(q_grad), + f"{{devices=[{n_devices},1,1,1]0,1,2,3}}") + self.assertEqual( + torch_xla._XLAC._get_xla_sharding_spec(k_grad), + f"{{devices=[{n_devices},1,1,1]0,1,2,3}}") + self.assertEqual( + torch_xla._XLAC._get_xla_sharding_spec(v_grad), + f"{{devices=[{n_devices},1,1,1]0,1,2,3}}") + torch_xla.sync() + + torch.manual_seed(42) + q = torch.randn(4, 2, 1024, 8, requires_grad=True).to("xla") + k = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla") + v = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla") + zeros = torch.zeros(4, 32).to("xla") + kv_segment_ids = torch.cat([zeros, zeros + 1, zeros + 2, zeros + 3], dim=1) + q_segment_ids = torch.ones(4, q.shape[2], dtype=torch.float32).to("xla") + q.retain_grad() + k.retain_grad() + v.retain_grad() + + o = self._attention( + q, + k, + v, + attn_mask=self._make_attention_mask_from_segment_ids( + q_segment_ids, kv_segment_ids)) + loss = o.sum() + loss.backward() + xm.mark_step() + + for i in [(q, q_grad), (k, k_grad), (v, v_grad)]: + self.assertTrue(torch.allclose(i[0].grad.cpu(), i[1].cpu(), atol=1e-05)) + jax.config.update("jax_default_matmul_precision", "default") + + if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO) torch.set_default_dtype(torch.float32) diff --git a/torch_xla/experimental/custom_kernel.py b/torch_xla/experimental/custom_kernel.py index 5e30ffba26a..733ee8cc8c3 100644 --- a/torch_xla/experimental/custom_kernel.py +++ b/torch_xla/experimental/custom_kernel.py @@ -237,7 +237,8 @@ def forward(ctx, q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab, ctx.sm_scale = sm_scale ctx.partition_spec = partition_spec ctx.mesh = mesh - ctx.full_shape = None + ctx.q_full_shape = None + ctx.kv_full_shape = None save_residuals = q.requires_grad or k.requires_grad or v.requires_grad # SPMD integration. @@ -247,7 +248,8 @@ def forward(ctx, q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab, full_v = v full_ab = ab if partition_spec is not None: - ctx.full_shape = q.shape + ctx.q_full_shape = q.shape + ctx.kv_full_shape = k.shape q = xs.enable_manual_sharding(q, partition_spec, mesh=mesh).global_tensor k = xs.enable_manual_sharding(k, partition_spec, mesh=mesh).global_tensor v = xs.enable_manual_sharding(v, partition_spec, mesh=mesh).global_tensor @@ -313,7 +315,7 @@ def forward(ctx, q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab, # SPMD integration if partition_spec is not None: o = xs.disable_manual_sharding( - o, partition_spec, ctx.full_shape, mesh=mesh).global_tensor + o, partition_spec, ctx.q_full_shape, mesh=mesh).global_tensor return o o, *aux = o l, m = (v[..., 0] for v in aux[-2:]) @@ -321,11 +323,13 @@ def forward(ctx, q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab, # SPMD integration if partition_spec is not None: o = xs.disable_manual_sharding( - o, partition_spec, ctx.full_shape, mesh=mesh).global_tensor + o, partition_spec, ctx.q_full_shape, mesh=mesh).global_tensor l = xs.disable_manual_sharding( - l, partition_spec[0:3], ctx.full_shape[0:3], mesh=mesh).global_tensor + l, partition_spec[0:3], ctx.q_full_shape[0:3], + mesh=mesh).global_tensor m = xs.disable_manual_sharding( - m, partition_spec[0:3], ctx.full_shape[0:3], mesh=mesh).global_tensor + m, partition_spec[0:3], ctx.q_full_shape[0:3], + mesh=mesh).global_tensor # q_segment_ids and kv_segment_ids are sharded here if partition_spec is provided # but it should be OK as the backward will use the same partition_spec @@ -342,7 +346,8 @@ def backward(ctx, grad_output): sm_scale = ctx.sm_scale partition_spec = ctx.partition_spec mesh = ctx.mesh - full_shape = ctx.full_shape + q_full_shape = ctx.q_full_shape + kv_full_shape = ctx.kv_full_shape # this segment_ids only reflects the local shape of segment_ids segment_ids = ctx.segment_ids grad_q = grad_k = grad_v = grad_ab = None @@ -467,11 +472,11 @@ def backward(ctx, grad_output): # SPMD integration if partition_spec is not None: grad_q = xs.disable_manual_sharding( - grad_q, partition_spec, full_shape, mesh=mesh).global_tensor + grad_q, partition_spec, q_full_shape, mesh=mesh).global_tensor grad_k = xs.disable_manual_sharding( - grad_k, partition_spec, full_shape, mesh=mesh).global_tensor + grad_k, partition_spec, kv_full_shape, mesh=mesh).global_tensor grad_v = xs.disable_manual_sharding( - grad_v, partition_spec, full_shape, mesh=mesh).global_tensor + grad_v, partition_spec, kv_full_shape, mesh=mesh).global_tensor return grad_q, grad_k, grad_v, None, None, None, None, grad_ab, None, None