Skip to content

Commit

Permalink
Cache flash attention tracing (#8026)
Browse files Browse the repository at this point in the history
  • Loading branch information
JackCaoG authored Sep 17, 2024
1 parent 03374cd commit 8b93611
Showing 1 changed file with 12 additions and 5 deletions.
17 changes: 12 additions & 5 deletions torch_xla/experimental/custom_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,10 @@ def trace_pallas(kernel: Callable,
global trace_pallas_arg_to_payload
# implicit assumption here that everything in kwargs is hashable and not a tensor,
# which is true for the gmm and tgmm.
hash_key = (kernel, static_argnums, tuple(static_argnames), tuple(jax_args),
repr(sorted(kwargs.items())).encode())
hash_key = (jax.config.jax_default_matmul_precision, kernel, static_argnums,
tuple(static_argnames)
if static_argnames is not None else static_argnames,
tuple(jax_args), repr(sorted(kwargs.items())).encode())
if hash_key in trace_pallas_arg_to_payload:
torch_xla._XLAC._xla_increment_counter('trace_pallas_cache_hit', 1)
return trace_pallas_arg_to_payload[hash_key], tensor_args
Expand Down Expand Up @@ -287,7 +289,9 @@ def forward(ctx, q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab,
min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k_major"], k.shape[2]),
min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k"], k.shape[2]),
False,
static_argnums=range(5, 13))
static_argnums=range(5, 13),
use_cache=True,
)

args = [q, k, v]
if ab is not None:
Expand Down Expand Up @@ -386,7 +390,9 @@ def backward(ctx, grad_output):
static_argnames=[
"block_q_major", "block_k_major", "block_k", "sm_scale", "causal",
"mask_value", "debug"
])
],
use_cache=True,
)

args = [q, k, v]
if ab is not None:
Expand Down Expand Up @@ -435,7 +441,8 @@ def backward(ctx, grad_output):
static_argnames=[
"block_q_major", "block_k_major", "block_k", "block_q",
"sm_scale", "causal", "mask_value", "debug"
])
],
use_cache=True)

grads = torch_xla._XLAC._xla_tpu_custom_call(args, payload,
[k.shape, v.shape],
Expand Down

0 comments on commit 8b93611

Please sign in to comment.