Deterministic Hang with unpack_reconfig_data_format at a specific spot in Flash Decode Kernel #9608

Closed
Tracked by #12330
caixunshiren opened this issue Jun 21, 2024 · 6 comments
Labels: bug, kernels, LLK, LLM_bug

Comments

@caixunshiren
Contributor

Description

In the flash decode op #9510, which I'm currently implementing, I'm seeing a hang when unpack_reconfig_data_format is called before a mul_tiles_bcast_xxx kernel. As shown below, uncommenting line a causes a hang, but uncommenting line b does not. It is worth noting that all inputs and intermediates are bf16, and the CBs involved are only used in compute.

  line a     // unpack_reconfig_data_format_srca(cb_out_accumulate_im_2); // DEBUG
                mul_block_bcast_cols_inplace(cb_out_accumulate_im_2, cb_exp_max_diff_2, Sq_chunk_t, DHt);
                /// O_2 = torch.matmul(torch.eye(padded_num_heads) * torch.exp(m_1 - m), O_1)
  line b     // unpack_reconfig_data_format(cb_out_accumulate_im, cb_exp_max_diff); // DEBUG
                // pack_reconfig_data_format(cb_out_accumulate_im);
                mul_block_bcast_cols_inplace(cb_out_accumulate_im, cb_exp_max_diff, Sq_chunk_t, DHt);

unpack_reconfig_data_format_srca(cb_out_accumulate_im_2); // DEBUG
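
For readers unfamiliar with the comment in the snippet above, here is a minimal host-side sketch of the math it describes. It reads `torch.matmul(torch.eye(padded_num_heads) * torch.exp(m_1 - m), O_1)` as a multiplication by a diagonal matrix, which is the same as scaling each row of O_1 by its own factor exp(m_1 - m), i.e. a column-broadcast multiply. The sizes, type aliases, and function names are hypothetical and exist only for this illustration; this is plain C++, not kernel code.

```cpp
#include <array>
#include <cassert>
#include <cmath>
#include <cstddef>

// Hypothetical sizes, chosen only to keep the example small.
constexpr std::size_t kHeads = 2;
constexpr std::size_t kDim = 3;

using Matrix = std::array<std::array<float, kDim>, kHeads>;
using Vector = std::array<float, kHeads>;

// Reference form: O_2 = diag(exp(m_1 - m)) @ O_1, written as a dense matmul.
Matrix rescale_via_diag_matmul(const Matrix& o1, const Vector& m1, const Vector& m) {
    Matrix out{};  // zero-initialized accumulator
    for (std::size_t i = 0; i < kHeads; ++i) {
        for (std::size_t j = 0; j < kDim; ++j) {
            for (std::size_t k = 0; k < kHeads; ++k) {
                // Entry (i, k) of eye * exp(m_1 - m): nonzero only on the diagonal.
                const float d = (i == k) ? std::exp(m1[k] - m[k]) : 0.0f;
                out[i][j] += d * o1[k][j];
            }
        }
    }
    return out;
}

// Broadcast form: each row i of O_1 is scaled by the single value exp(m_1[i] - m[i]).
Matrix rescale_via_bcast_cols(const Matrix& o1, const Vector& m1, const Vector& m) {
    Matrix out{};
    for (std::size_t i = 0; i < kHeads; ++i) {
        const float scale = std::exp(m1[i] - m[i]);
        for (std::size_t j = 0; j < kDim; ++j) {
            out[i][j] = scale * o1[i][j];
        }
    }
    return out;
}
```

Both functions should return the same matrix for any input, which is why the kernel can use a broadcast multiply instead of a matmul for this rescaling step.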

Things we tried

  • Adding TTI_STALLWAIT(p_stall::STALL_CFG, p_stall::UNPACK) and TTI_STALLWAIT(p_stall::STALL_UNPACK, p_stall::TRISC_CFG) before and after line a. This did not resolve the hang.
  • Adding TTI_STALLWAIT(p_stall::STALL_UNPACK, p_stall::UNPACK) before and after line a. This did not resolve the hang.
  • Commenting out all the compute and narrowing the problem down to the mul_block_bcast_cols_inplace function. Commenting out the compute call mul_tiles_bcast_scalar inside it resolves the hang.

To Repro

I pushed a debug branch which is rebased on latest main from June 20th:
https://github.com/tenstorrent/tt-metal/tree/xuncai/flash-decode-reconfig-dataformat-hang
After building the branch, (optional) enable dprint:

export TT_METAL_DPRINT_CORES=0,0

Then, run the following command:

pytest -svv tests/tt_eager/python_api_testing/unit_testing/misc/test_scaled_dot_product_attention_decode.py

You should see [C] R ckpt 3 printed 3 times, but [C] R ckpt 3.1 never prints.
If you comment out line 629 in sdpa_flash_decode.cpp, the test should no longer hang.

unpack_reconfig_data_format_srca(cb_out_accumulate_im_2); // DEBUG

FYI @cglagovichTT

@caixunshiren
Contributor Author

An update to this issue:

I found out in #13364 that input/output CBs do not support data format reconfig properly. Calling unpack_reconfig_data_format on output CBs causes a hang, and calling pack_reconfig_data_format on input CBs degrades PCC.

@ttmtrajkovic @rtawfik01 are you aware of this issue? It seems like this is a software design choice, not a hardware constraint -- I'd like to file an issue to have it supported. Do you know who I should talk to?

@ttmtrajkovic
Contributor

Adding @rdjogoTT, who has recently worked on refactoring the reconfig functions. @rdjogoTT, could you please review this feedback?

@rdjogoTT
Contributor

@caixunshiren @ttmtrajkovic
This issue is caused by how the JIT build generates unpack_src_format, unpack_dst_format, pack_src_format, and pack_dst_format.

The unpack_dst_format is set to DataFormat::Invalid for output CBs (16-23), but is used in reconfig_data_format (what used to be unpack_reconfig_data_format before my recent changes).

Similarly, pack_src_format is set to DataFormat::Invalid for input and param CBs (0-15), but is needed in pack_reconfig_data_format.

Right now, only intermediate CBs can be used for both input and output within compute kernels.
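
To make the restriction above concrete, here is a small standalone model of it, as a sketch under stated assumptions and not the actual JIT-generated code. The enum, the helper names, and the idea that every valid CB holds bf16 are hypothetical stand-ins. The index ranges 0-15 and 16-23 come from this comment; treating 24-31 as the intermediate CBs is an assumption.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for the real DataFormat enum: only two values are modeled.
enum class DataFormat : std::uint8_t { Float16_b, Invalid };

// Index ranges as described in the comment above.
constexpr bool is_input_or_param_cb(int cb) { return cb >= 0 && cb <= 15; }
constexpr bool is_output_cb(int cb) { return cb >= 16 && cb <= 23; }

// Model of the generated tables for a kernel where every CB holds bf16:
// the unpacker has no valid format for output CBs ...
constexpr DataFormat unpack_dst_format(int cb) {
    return is_output_cb(cb) ? DataFormat::Invalid : DataFormat::Float16_b;
}
// ... and the packer has no valid format for input and param CBs.
constexpr DataFormat pack_src_format(int cb) {
    return is_input_or_param_cb(cb) ? DataFormat::Invalid : DataFormat::Float16_b;
}

// A reconfig call is only meaningful when the table entry it reads is valid.
constexpr bool unpack_reconfig_is_safe(int cb) {
    return unpack_dst_format(cb) != DataFormat::Invalid;
}
constexpr bool pack_reconfig_is_safe(int cb) {
    return pack_src_format(cb) != DataFormat::Invalid;
}

// Input CB 0: unpack reconfig is fine, pack reconfig reads Invalid.
static_assert(unpack_reconfig_is_safe(0) && !pack_reconfig_is_safe(0));
// Output CB 16: pack reconfig is fine, unpack reconfig reads Invalid.
static_assert(!unpack_reconfig_is_safe(16) && pack_reconfig_is_safe(16));
// CB 24 (assumed intermediate): both directions are valid.
static_assert(unpack_reconfig_is_safe(24) && pack_reconfig_is_safe(24));
```

In this model, only the indices outside 0-23 pass both checks, which matches the statement that only intermediate CBs can serve as both input and output within a compute kernel.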

@rdjogoTT
Contributor

This should no longer be a problem with #14971.
The restriction for CB types was removed with that PR.
Can you try to repro again?

@prajaramanTT

@caixunshiren Is this still an open issue? If not, can you please mark it closed? Thanks.

@caixunshiren
Contributor Author

@prajaramanTT No, it's no longer an issue. Closing this now. Thanks.
