Enable FP32 Accumulate in Flash Attention and Flash Decode #13364

Open
2 of 4 tasks
caixunshiren opened this issue Oct 2, 2024 · 1 comment
Labels: flash-attention, flash-decode, kernels, llama3, models, P1

Comments

caixunshiren (Contributor) commented Oct 2, 2024

Description

We do not have support for fp32 accumulate in the sdpa family of kernels. This becomes a problem when the number of chunks gets large: the PCC against ground truth starts to diverge. For models that require 128K sequence length, this is problematic.

This issue tracks the enabling of fp32 accumulate in the following kernels:

round 1:

  • sdpa (bf16 cbs, fp32 accum)
  • sdpa decode (bf16 cbs, fp32 accum)

round 2:

  • sdpa (fp32 cbs, fp32 accum)
  • sdpa decode (fp32 cbs, fp32 accum)
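To make the failure mode concrete, here is a minimal host-side sketch (PyTorch, not the kernel code; shapes, sequence length, and chunk size are illustrative) that runs a flash-attention-style online softmax over KV chunks and compares a bf16 accumulator against an fp32 accumulator, using PCC against an fp64 reference:

```python
import torch

def pcc(a, b):
    # Pearson correlation coefficient between two flattened tensors.
    a, b = a.flatten().double(), b.flatten().double()
    return torch.corrcoef(torch.stack([a, b]))[0, 1].item()

def chunked_sdpa(q, k, v, chunk, acc_dtype):
    # Flash-attention-style online softmax; `acc_dtype` is the accumulator
    # precision (a stand-in for the kernel's dest-accumulate mode).
    scale = q.shape[-1] ** -0.5
    out = torch.zeros(q.shape, dtype=acc_dtype)                  # running numerator
    row_max = torch.full(q.shape[:-1] + (1,), float("-inf"), dtype=acc_dtype)
    row_sum = torch.zeros(q.shape[:-1] + (1,), dtype=acc_dtype)  # running denominator
    for s in range(0, k.shape[-2], chunk):
        ks, vs = k[s:s + chunk], v[s:s + chunk]
        scores = (q.to(acc_dtype) @ ks.to(acc_dtype).transpose(-1, -2)) * scale
        new_max = torch.maximum(row_max, scores.amax(-1, keepdim=True))
        corr = torch.exp(row_max - new_max)                      # rescale old stats
        p = torch.exp(scores - new_max)
        row_sum = row_sum * corr + p.sum(-1, keepdim=True)
        out = out * corr + p @ vs.to(acc_dtype)
        row_max = new_max
    return (out / row_sum).to(q.dtype)

torch.manual_seed(0)
seq, d = 32 * 1024, 128                  # the issue targets up to 128K sequence length
q = torch.randn(1, 1, d, dtype=torch.bfloat16)
k = torch.randn(seq, d, dtype=torch.bfloat16)
v = torch.randn(seq, d, dtype=torch.bfloat16)
ref = chunked_sdpa(q.double(), k.double(), v.double(), seq, torch.float64)  # fp64 ground truth
for acc in (torch.bfloat16, torch.float32):
    out = chunked_sdpa(q, k, v, chunk=256, acc_dtype=acc)
    print(acc, "PCC vs fp64 reference:", pcc(out, ref))
```

As the number of chunks grows, the bf16 accumulator's PCC against the reference is expected to degrade while the fp32 accumulator stays near 1.0, which is the gap this issue is about.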

FYI @cglagovichTT

caixunshiren (Contributor, Author) commented:

Update:

  • It appears that most of the PCC drop is attributable to math approximation, not fp32 accumulate. That issue is tracked here: Fix Diverging PCC issue in SDPA Kernels #13866
  • Based on my experiments, fp32 accumulate works fine with bf16 cbs. With fp32 cbs, we need some inputs to mul_block_inplace and add_block_inplace to be bf16 cbs during intermediate and stat accumulation/updates, otherwise we get PCC degradation compared to bf16. My WIP work is on branch sdpa-fp32-investigations.
  • I also found that reconfig_dataformat doesn't work as expected and can hang or produce wrong results for in/out_cb. This explains the flash decode fp32 accumulate hang on reducer cores that I saw earlier: Deterministic Hang with unpack_reconfig_data_format at a specific spot in Flash Decode Kernel #9608
  • Actions for now: support fp32 accumulate with bf16 cbs in the sdpa/sdpa-decode kernels (model-side usage sketched below); leave fp32 cbs for later.
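For reference, the model-side call would look roughly like the sketch below once fp32 accumulate lands. Exact op and config field names are from memory and may differ across ttnn versions; q, k, v are assumed to be bf16 device tensors already allocated. Treat this as a sketch, not the final interface.

```python
import ttnn

# Compute kernel config requesting fp32 dest accumulation and exact math.
compute_kernel_config = ttnn.WormholeComputeKernelConfig(
    math_fidelity=ttnn.MathFidelity.HiFi4,  # full-fidelity multiplies
    math_approx_mode=False,                 # exact exp/recip (see #13866)
    fp32_dest_acc_en=True,                  # fp32 accumulate in DST
    packer_l1_acc=False,
)

out = ttnn.transformer.scaled_dot_product_attention(
    q, k, v,
    is_causal=True,
    compute_kernel_config=compute_kernel_config,
)
```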

ct-clmsn pushed a commit to ct-clmsn/tt-metal that referenced this issue Nov 12, 2024