
Commit

Revert the default
alanwaketan committed Mar 21, 2024
1 parent 4734d78 commit 766862b
Showing 2 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions test/test_pallas.py
@@ -219,7 +219,7 @@ def attention(q, k, v):
k = torch.randn(3, 2, 128, 4).to("xla")
v = torch.randn(3, 2, 128, 4).to("xla")

- o = flash_attention(q, k, v, causal=False)
+ o = flash_attention(q, k, v)
expected_o = attention(q, k, v)
self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu()))
jax.config.update('jax_default_matmul_precision', jax.lax.Precision.DEFAULT)
@@ -243,7 +243,7 @@ def attention(q, k, v):

# The causal mask is turned on by default in the wrapper.
# It masks out the top right triangle of the attention matrix, therefore it speeds up the compute but also changes the output.
- o = flash_attention(q, k, v)
+ o = flash_attention(q, k, v, causal=True)
expected_o = attention(q, k, v)
self.assertFalse(torch.allclose(o.cpu(), expected_o.cpu()))
jax.config.update('jax_default_matmul_precision', jax.lax.Precision.DEFAULT)
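
Note (not part of the diff): the comment in the second test refers to the causal mask that the wrapper can apply. Below is a minimal reference sketch of what that mask does; the function name and the unscaled scores are assumptions for illustration, not code from this repository. The mask zeroes out the upper-right triangle of the score matrix, so the causal output intentionally differs from the plain attention(q, k, v) reference, which is exactly what the second test asserts.

import torch

def causal_reference_attention(q, k, v):
    # q: [batch, heads, q_len, d], k/v: [batch, heads, kv_len, d]
    scores = torch.einsum("bhqd,bhkd->bhqk", q, k)
    # Causal mask: query position i may only attend to key positions <= i,
    # i.e. the upper-right triangle of the score matrix is masked out.
    mask = torch.triu(
        torch.ones(scores.shape[-2], scores.shape[-1], dtype=torch.bool),
        diagonal=1)
    scores = scores.masked_fill(mask, float("-inf"))
    return torch.einsum("bhqk,bhkd->bhqd", scores.softmax(dim=-1), v)
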
4 changes: 2 additions & 2 deletions torch_xla/experimental/custom_kernel.py
@@ -126,13 +126,13 @@ def wrapped_kernel(kernel: Callable,


# This is a simplified wrapper on top of https://github.com/google/jax/blob/b2058d72b7e1693a41303d5411572aabf99b7981/jax/experimental/pallas/ops/tpu/flash_attention.py#L139
- # where we only takes q, k, v, and segment_ids as input and set causal and block_sizes for the users.
+ # where we only takes q, k, v, segment_ids and causal as input and set block_sizes for the users.
def flash_attention(
q, # [batch_size, num_heads, q_seq_len, d_model]
k, # [batch_size, num_heads, kv_seq_len, d_model]
v, # [batch_size, num_heads, kv_seq_len, d_model]
segment_ids=None, # q of [batch_size, q_seq_len] and kv of [batch_size, kv_seq_len]
- causal=True,
+ causal=False,
):
# Import JAX within the function such that we don't need to call the jax_import_guard()
# in the global scope which could cause problems for xmp.spawn.
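
Note (not part of the diff): a minimal usage sketch of the reverted default, assuming a TPU-backed XLA device; shapes mirror the test above. With causal now defaulting to False, the bare call runs non-causal attention, and the causal mask must be requested explicitly.

import torch
import torch_xla
from torch_xla.experimental.custom_kernel import flash_attention

q = torch.randn(3, 2, 128, 4).to("xla")
k = torch.randn(3, 2, 128, 4).to("xla")
v = torch.randn(3, 2, 128, 4).to("xla")

o = flash_attention(q, k, v)                      # non-causal (new default, causal=False)
o_causal = flash_attention(q, k, v, causal=True)  # causal mask applied explicitly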
