Enable fwd and varlen_fwd on AMD #60

Merged: 4 commits merged on Jun 19, 2024
Changes from 1 commit
3 changes: 2 additions & 1 deletion .github/workflows/amd_tests.yml
@@ -59,4 +59,5 @@ jobs:
python setup.py install
- name: Test
run: |
pytest tests/test_flash_attn.py::test_flash_attn_output
pytest tests/test_flash_attn.py::test_flash_attn_varlen_output
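For local verification, the same two test selections that the workflow now runs on AMD can be invoked through pytest directly. A minimal sketch, assuming flash-attn with the Triton AMD backend is already built and installed as in the job above:

    import pytest

    # Run the two test functions enabled for AMD in the CI workflow above.
    pytest.main([
        "tests/test_flash_attn.py::test_flash_attn_output",
        "tests/test_flash_attn.py::test_flash_attn_varlen_output",
    ])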
76 changes: 65 additions & 11 deletions flash_attn/flash_attn_triton_amd.py
@@ -1577,7 +1577,72 @@ def fwd(q,

return tri_out, q , k , v, o, softmax_lse, softmax_p, torch.get_rng_state()

def varlen_fwd(
q,
k,
v,
o,
cu_seqlens_q,
cu_seqlens_k,
seqused_k,
block_table_,
alibi_slopes,
max_seqlen_q,
max_seqlen_k,
dropout_p,
softmax_scale,
zero_tensors,
causal,
window_size_left,
window_size_right,
return_softmax,
gen_):

print("flash_attn_triton_amd.py::varlen_fwd")
print("q:", q.shape)
print("k:", k.shape)
print("v:", v.shape)

if dropout_p != 0.0:
raise ValueError("dropout is not supported on HIP")



if o is None:
o = torch.empty_like(q)



# create metadata object
input_metadata = MetaData(sm_scale=softmax_scale)
input_metadata.set_varlen_params(cu_seqlens_q, cu_seqlens_k)

# get shapes
batch, nheads_q, nheads_k, head_size = get_shape_from_layout(q, k, input_metadata)

# Setup metadata
if causal:
input_metadata.need_causal()
# if bias is not None:
# metadata.need_bias(bias, q.shape[0], q.shape[1], q.shape[2], k.shape[2])
if alibi_slopes is not None:
input_metadata.need_alibi(alibi_slopes, batch, nheads_q)
if dropout_p > 0.0:
input_metadata.need_dropout(dropout_p, return_softmax)

# Check arguments
input_metadata.check_args(q, k, v, o)

# Perform the forward attention computation
tri_out, encoded_softmax = attention(q, k, v, o, input_metadata)

softmax_lse = encoded_softmax
softmax_p = encoded_softmax

return tri_out, q , k , v, o, softmax_lse, softmax_p, torch.get_rng_state()

def fwd_kvcache(*args, **kwargs):
pass


def bwd(dout, q, k, v, out, softmax_lse, dq, dk, dv, alibi_slopes, dropout_p, softmax_scale, causal, window_size_left,
@@ -1729,21 +1794,10 @@ def bwd(dout, q, k, v, out, softmax_lse, dq, dk, dv, alibi_slopes, dropout_p, so
return dq, dk, dv, None





def varlen_fwd(q, k, v, *args, **kwargs):
pass



def varlen_bwd(dout, q, k, v, out, softmax_lse, dq, dk, dv, *args, **kwargs):
pass


def fwd_kvcache(*args, **kwargs):
pass



# /////////////////////////////////////////// CLI //////////////////////////////////////////////////////////
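The varlen_fwd entry point added above consumes cumulative sequence-length tensors (cu_seqlens_q, cu_seqlens_k) instead of a padded batch. A hypothetical sketch, not part of this PR, of how a caller typically packs variable-length sequences and builds those tensors:

    import torch

    # Per-batch query lengths (example values); cu_seqlens is their prefix sum with a leading zero.
    seqlens_q = torch.tensor([3, 5, 2], dtype=torch.int32)
    cu_seqlens_q = torch.zeros(len(seqlens_q) + 1, dtype=torch.int32)
    cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)     # tensor([0, 3, 8, 10])
    max_seqlen_q = int(seqlens_q.max())

    # q is packed along the token dimension: (total_tokens, nheads_q, head_size).
    nheads_q, head_size = 4, 64
    q_packed = torch.randn(int(cu_seqlens_q[-1]), nheads_q, head_size, dtype=torch.float16)

The same construction applies to cu_seqlens_k and max_seqlen_k on the key/value side.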
11 changes: 10 additions & 1 deletion tests/test_flash_attn.py
@@ -896,7 +896,8 @@ def test_flash_attn_output(

# skip all cases where seqlen_q, seqlen_k, or d are not powers of 2
if not (is_power_of_2(seqlen_q) and is_power_of_2(seqlen_k) and is_power_of_2(d)):
pytest.skip("seqlen_q, seqlen_k, or d are not powers of 2")

if (
max(seqlen_q, seqlen_k) >= 2048
and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
@@ -1181,6 +1182,14 @@ def test_flash_attn_output(
def test_flash_attn_varlen_output(
seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked, forward_only=True
):
if is_hip():
if dropout_p != 0.0:
pytest.skip("Dropout not supported in HIP")

# skip all cases where seqlen_q, seqlen_k, or d are not powers of 2
if not (is_power_of_2(seqlen_q) and is_power_of_2(seqlen_k) and is_power_of_2(d)):
pytest.skip("seqlen_q, seqlen_k, or d are not powers of 2")

if (
max(seqlen_q, seqlen_k) >= 2048
and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
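Both skip conditions above depend on an is_power_of_2 helper. A minimal sketch of the standard bit-trick form such a helper usually takes (the test file's actual definition is not shown in this diff and may differ):

    def is_power_of_2(n: int) -> bool:
        # A positive integer is a power of two iff exactly one bit is set,
        # i.e. n & (n - 1) clears that bit to zero.
        return n > 0 and (n & (n - 1)) == 0

    assert is_power_of_2(128)
    assert not is_power_of_2(96)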