From b6ea085bd1e5b386b992e0716e06d77f4c647206 Mon Sep 17 00:00:00 2001
From: Michael Melesse
Date: Fri, 14 Jun 2024 10:44:59 -0500
Subject: [PATCH] clean up

---
 flash_attn/flash_attn_triton_interface_amd.py | 74 ++++++++++---------
 1 file changed, 41 insertions(+), 33 deletions(-)

diff --git a/flash_attn/flash_attn_triton_interface_amd.py b/flash_attn/flash_attn_triton_interface_amd.py
index d27619ac3..2ca749968 100644
--- a/flash_attn/flash_attn_triton_interface_amd.py
+++ b/flash_attn/flash_attn_triton_interface_amd.py
@@ -2,10 +2,8 @@
 import torch
 import triton
 
-# /////////////////////////////////////////// Interface //////////////////////////////////////////////////////////
 DEBUG=False
 
-
 def fwd(q,
         k,
         v,
@@ -35,32 +33,34 @@ def fwd(q,
     if dropout_p != 0.0:
         raise ValueError("dropout is not supported on HIP")
 
-
     if o is None:
         o = torch.empty_like(q)
 
+    # Setup metadata
+    input_metadata = MetaData(sm_scale=softmax_scale)
+    input_metadata.max_seqlens_q = q.shape[1]
+    input_metadata.max_seqlens_k = k.shape[1]
+    input_metadata.layout = "bshd"
 
-    # Create metadata object
-    metadata = MetaData(sm_scale=softmax_scale)
-    metadata.max_seqlens_q = q.shape[1]
-    metadata.max_seqlens_k = k.shape[1]
-    metadata.layout = "bshd"
+    batch, nheads_q, nheads_k, head_size = get_shape_from_layout(q, k, input_metadata)
 
-    # Setup metadata
     if causal:
-        metadata.need_causal()
+        input_metadata.need_causal()
+
     # if bias is not None:
-    #     metadata.need_bias(bias, q.shape[0], q.shape[1], q.shape[2], k.shape[2])
+    #     input_metadata.need_bias(bias, batch, nheads_q, input_metadata.max_seqlens_q, input_metadata.max_seqlens_k)
+
     if alibi_slopes is not None:
-        metadata.need_alibi(alibi_slopes, q.shape[0], q.shape[2])
+        input_metadata.need_alibi(alibi_slopes, batch, nheads_q)
+
     if dropout_p > 0.0:
-        metadata.need_dropout(dropout_p, return_softmax)
+        input_metadata.need_dropout(dropout_p, return_softmax)
 
     # Check arguments
-    metadata.check_args(q, k, v, o)
+    input_metadata.check_args(q, k, v, o)
 
     # Perform the forward attention computation
-    tri_out, encoded_softmax = attention(q, k, v, o, metadata)
+    tri_out, encoded_softmax = attention(q, k, v, o, input_metadata)
     softmax_lse = encoded_softmax
     softmax_p = encoded_softmax
 
@@ -96,28 +96,26 @@ def varlen_fwd(
     if dropout_p != 0.0:
         raise ValueError("dropout is not supported on HIP")
 
-
-
     if o is None:
         o = torch.empty_like(q)
 
 
-
-
-    # create metadata object
+    # Setup metadata
     input_metadata = MetaData(sm_scale=softmax_scale)
     input_metadata.set_varlen_params(cu_seqlens_q, cu_seqlens_k)
 
     # get shapes
     batch, nheads_q, nheads_k, head_size = get_shape_from_layout(q, k, input_metadata)
 
-    # Setup metadata
     if causal:
         input_metadata.need_causal()
+
     # if bias is not None:
-    #     metadata.need_bias(bias, q.shape[0], q.shape[1], q.shape[2], k.shape[2])
+    #     input_metadata.need_bias(bias, batch, nheads_q, q.shape[2], k.shape[2])
+
     if alibi_slopes is not None:
         input_metadata.need_alibi(alibi_slopes, batch, nheads_q)
+
     if dropout_p > 0.0:
         input_metadata.need_dropout(dropout_p, return_softmax)
 
@@ -132,7 +130,25 @@ def varlen_fwd(
     return tri_out, q , k , v, o, softmax_lse, softmax_p, torch.get_rng_state()
 
 
-def fwd_kvcache(*args, **kwargs):
+def fwd_kvcache(
+        q,
+        k_cache,
+        v_cache,
+        k,
+        v,
+        cache_seqlens,
+        rotary_cos,
+        rotary_sin,
+        cache_batch_idx,
+        block_table,
+        alibi_slopes,
+        out,
+        softmax_scale,
+        causal,
+        window_size_left,
+        window_size_right,
+        rotary_interleaved,
+        num_splits):
     pass
 
 
@@ -157,12 +173,9 @@ def bwd(dout, q, k, v, out, softmax_lse, dq, dk, dv, alibi_slopes, dropout_p, so
         print("gen_:", gen_)
         print("rng_state:", rng_state)
 
-
-
     if out is None:
         out = torch.empty_like(q)
 
-
     # Ensure the tensors have requires_grad=True
     q.requires_grad_()
     k.requires_grad_()
@@ -187,25 +200,20 @@ def bwd(dout, q, k, v, out, softmax_lse, dq, dk, dv, alibi_slopes, dropout_p, so
     # Setup metadata
     if causal:
         metadata.need_causal()
+
     # if bias is not None:
     #     metadata.need_bias(bias, q.shape[0], q.shape[1], q.shape[2], k.shape[2])
     return_softmax = True
     if alibi_slopes is not None:
         metadata.need_alibi(alibi_slopes, batch, nheads_q)
+
     if dropout_p > 0.0:
         metadata.need_dropout(dropout_p, return_softmax)
 
     # Check arguments
     metadata.check_args(q, k, v, out)
-
-
-    # tri_out, _ = attention(q, k, v, out, metadata)
-    # tri_out.requires_grad_()
-    # dout.requires_grad_()
-    # tri_out.backward(dout)
-
     # write your own version backward
    M = torch.empty((batch, nheads_q, metadata.max_seqlens_q), device=q.device, dtype=torch.float32) # this passed from