diff --git a/test/test_pallas.py b/test/test_pallas.py
index 8901a84c80a..4c755f5db2e 100644
--- a/test/test_pallas.py
+++ b/test/test_pallas.py
@@ -643,24 +643,23 @@ def test_flash_attention_wrapper_segment_ids_1(self):
     q = torch.randn(3, 2, 128, 4)
     k = torch.randn(3, 2, 128, 4)
     v = torch.randn(3, 2, 128, 4)
-    q_segment_ids = torch.zeros(3, 128)
-    kv_segment_ids = torch.zeros(3, 128)
+    zeros = torch.zeros(3, 32)
+    segment_ids = torch.cat([zeros, zeros + 1, zeros + 2, zeros + 3], dim=1)
     o = flash_attention(
-        q.to("xla"), k.to("xla"), v.to("xla"), False, q_segment_ids.to("xla"),
-        kv_segment_ids.to("xla"))
+        q.to("xla"), k.to("xla"), v.to("xla"), False, segment_ids.to("xla"),
+        segment_ids.to("xla"))

     jax_q = jnp.array(q.numpy(), dtype=jnp.float32)
     jax_k = jnp.array(k.numpy(), dtype=jnp.float32)
     jax_v = jnp.array(v.numpy(), dtype=jnp.float32)
-    jax_q_segment_ids = jnp.array(q_segment_ids.numpy(), dtype=jnp.float32)
-    jax_kv_segment_ids = jnp.array(kv_segment_ids.numpy(), dtype=jnp.float32)
+    jax_segment_ids = jnp.array(segment_ids.numpy(), dtype=jnp.float32)
     expected_o = torch.from_numpy(
         np.array(
             jax_flash_attention(
                 jax_q,
                 jax_k,
                 jax_v,
-                segment_ids=SegmentIds(jax_q_segment_ids, jax_kv_segment_ids),
+                segment_ids=SegmentIds(jax_segment_ids, jax_segment_ids),
             )))

     self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu(), atol=1e-05))
@@ -674,16 +673,16 @@ def test_flash_attention_wrapper_segment_ids_2(self):
     q = torch.randn(3, 2, 128, 4).to("xla")
     k = torch.randn(3, 2, 128, 4).to("xla")
     v = torch.randn(3, 2, 128, 4).to("xla")
-    q_segment_ids = torch.zeros(3, 128).to("xla")
-    kv_segment_ids = torch.zeros(3, 128).to("xla")
-    o = flash_attention(q, k, v, False, q_segment_ids, kv_segment_ids)
+    zeros = torch.zeros(3, 32).to("xla")
+    segment_ids = torch.cat([zeros, zeros + 1, zeros + 2, zeros + 3], dim=1)
+    o = flash_attention(q, k, v, False, segment_ids, segment_ids)

     expected_o = self._attention(
         q,
         k,
         v,
         attn_mask=self._make_attention_mask_from_segment_ids(
-            q_segment_ids, kv_segment_ids))
+            segment_ids, segment_ids))
     self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu(), atol=1e-05))
     jax.config.update('jax_default_matmul_precision', jax.lax.Precision.DEFAULT)

@@ -697,13 +696,13 @@ def test_flash_attention_backward_segment_ids(self):
     q = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
     k = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
     v = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
-    q_segment_ids = torch.zeros(4, 128).to("xla")
-    kv_segment_ids = torch.zeros(4, 128).to("xla")
+    zeros = torch.zeros(4, 32).to("xla")
+    segment_ids = torch.cat([zeros, zeros + 1, zeros + 2, zeros + 3], dim=1)
     q.retain_grad()
     k.retain_grad()
     v.retain_grad()

-    o = flash_attention(q, k, v, False, q_segment_ids, kv_segment_ids)
+    o = flash_attention(q, k, v, False, segment_ids, segment_ids)
     loss = o.sum()
     loss.backward()
     xm.mark_step()
@@ -716,8 +715,8 @@ def test_flash_attention_backward_segment_ids(self):
     q = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
     k = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
     v = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
-    q_segment_ids = torch.zeros(4, 128).to("xla")
-    kv_segment_ids = torch.zeros(4, 128).to("xla")
+    zeros = torch.zeros(4, 32).to("xla")
+    segment_ids = torch.cat([zeros, zeros + 1, zeros + 2, zeros + 3], dim=1)
     q.retain_grad()
     k.retain_grad()
     v.retain_grad()
@@ -727,7 +726,7 @@ def test_flash_attention_backward_segment_ids(self):
         k,
         v,
         attn_mask=self._make_attention_mask_from_segment_ids(
-            q_segment_ids, kv_segment_ids))
+            segment_ids, segment_ids))
     loss = o.sum()
     loss.backward()
     xm.mark_step()
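Note: a minimal sketch of how segment ids turn into an attention mask, to show why the new non-trivial ids matter. The helper name and the "True means attention is permitted" convention are assumptions for illustration; they are not the `_make_attention_mask_from_segment_ids` helper defined in test_pallas.py, which may use the opposite convention.

import torch

def segment_ids_to_attention_mask(q_segment_ids, kv_segment_ids):
  # True where a query position may attend to a key position, i.e. where the
  # segment ids match. Shapes: (batch, q_len) and (batch, kv_len) broadcast to
  # (batch, 1, q_len, kv_len), so the mask applies across attention heads.
  return q_segment_ids[:, None, :, None] == kv_segment_ids[:, None, None, :]

zeros = torch.zeros(3, 32)
segment_ids = torch.cat([zeros, zeros + 1, zeros + 2, zeros + 3], dim=1)
mask = segment_ids_to_attention_mask(segment_ids, segment_ids)
# Four 32-token segments give a block-diagonal mask: each query attends only to
# keys in its own segment. With the previous all-zero segment ids every
# position shared segment 0, so the mask permitted everything and the tests
# never exercised the masking path.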