custom_kernel: fix shape mismatch by sharding segment_ids in flash attn. #8333
base: master
Thanks for the contribution! It looks like your use case runs into numerical inaccuracies, which may suggest that the enable_manual_sharding API calls need correction. I suggest adding a test case to verify and debug the issue in a small example.
Here is a reference for kernel tests.
Force-pushed from 4ba9067 to 0608900.
@miladm Hi! We have added a test that currently fails (16% of the values are correct, the others are not). I hope it will help you understand what's wrong.
@dudulightricks Thanks for submitting the test code. We reviewed it internally. It seems that if you shard both the KV and Q segment_ids, the query no longer attends to all KV elements in the matmul, hence the numerical inconsistency. Have you tried sharding only the query segment_ids?
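To make the concern above concrete, here is a minimal pure-Python sketch. It is not the torch_xla kernel: it assumes the segment_ids are split along the sequence axis, and every name in it (`segment_mask`, `allowed_pairs`, `split`, the toy input) is hypothetical. It only counts which (query, key) pairs the segment mask permits under each sharding choice.

```python
def segment_mask(q_seg, kv_seg):
    """mask[i][j] is True when query i and key j share a segment id."""
    return [[q == k for k in kv_seg] for q in q_seg]


def allowed_pairs(q_seg, kv_seg, q_offset=0, kv_offset=0):
    """Global (query, key) index pairs that the mask lets through."""
    mask = segment_mask(q_seg, kv_seg)
    return {
        (i + q_offset, j + kv_offset)
        for i, row in enumerate(mask)
        for j, ok in enumerate(row)
        if ok
    }


def split(xs, n):
    """Split a list into n equal contiguous shards."""
    size = len(xs) // n
    return [xs[i * size:(i + 1) * size] for i in range(n)]


# Hypothetical toy input: 8 tokens, two segments, and segment 1
# straddles the boundary between the two sequence shards.
q_seg = kv_seg = [0, 0, 0, 1, 1, 1, 1, 1]
num_shards = 2
shard_len = len(q_seg) // num_shards

# Reference: unsharded attention.
full = allowed_pairs(q_seg, kv_seg)

# Case A: Q and KV segment ids are both sharded, so each shard
# only ever sees its own slice of the keys.
both = set()
for s, (q, kv) in enumerate(zip(split(q_seg, num_shards),
                                split(kv_seg, num_shards))):
    both |= allowed_pairs(q, kv, s * shard_len, s * shard_len)

# Case B: only the Q segment ids are sharded; KV stays replicated.
q_only = set()
for s, q in enumerate(split(q_seg, num_shards)):
    q_only |= allowed_pairs(q, kv_seg, s * shard_len, 0)

missing = full - both
```

In this toy setup, case B reproduces the unsharded mask exactly, while case A drops the pairs that link token 3 to tokens 4 through 7 across the shard boundary. Whether the real kernel behaves the same way depends on which tensor axis the partition spec actually shards.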
@miladm I just did and the test still fails, but why would something like this happen anyway? We are sharding the model and the data and expect consistency in the results in any sharding case. Can't we trust the result in any sharding case?
When sharding support was added to this module, segment_ids weren't taken into account, which causes a shape-mismatch failure when they are used in sharded flash attention.
Force-pushed from 0608900 to b5d1b8f.
Description: This PR addresses an issue where segment_ids were not considered when adding sharding support in this module. The absence of segment_ids handling results in a shape mismatch failure when using them in sharded Flash Attention.
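The shape mismatch described above can be illustrated with a small stdlib-only sketch. It rests on assumptions that are not confirmed by this PR: q is laid out as (batch, heads, seq, head_dim), segment_ids as (batch, seq), and the mesh has a single axis named "data". The helpers `local_shape` and `compatible` are hypothetical and only do shape arithmetic; they are not torch_xla APIs.

```python
def local_shape(global_shape, partition_spec, mesh):
    """Per-device shape after splitting each dim by its mesh axis size."""
    shape = []
    for dim, axis in zip(global_shape, partition_spec):
        parts = 1 if axis is None else mesh[axis]
        if dim % parts != 0:
            raise ValueError(f"dim {dim} is not divisible by {parts}")
        shape.append(dim // parts)
    return tuple(shape)


def compatible(q_local, seg_local):
    """segment_ids must match q on the (batch, seq) dims (assumed layout)."""
    return (q_local[0], q_local[2]) == tuple(seg_local)


# Hypothetical mesh and shapes, chosen only for illustration.
mesh = {"data": 4}
q_shape = (8, 2, 128, 64)      # (batch, heads, seq, head_dim), assumed
seg_shape = (8, 128)           # (batch, seq), assumed
q_spec = ("data", None, None, None)

q_local = local_shape(q_shape, q_spec, mesh)

# segment_ids left unsharded: the batch dim no longer matches q.
seg_unsharded = local_shape(seg_shape, (None, None), mesh)

# segment_ids sharded with a spec derived from q's batch and seq axes.
seg_spec = (q_spec[0], q_spec[2])
seg_sharded = local_shape(seg_shape, seg_spec, mesh)

print(q_local, seg_unsharded, compatible(q_local, seg_unsharded))
print(q_local, seg_sharded, compatible(q_local, seg_sharded))
```

Deriving the segment_ids spec from the q spec keeps the two in step, which is one plausible way to avoid the mismatch. As the edit below notes, matching shapes alone does not guarantee numerically correct attention.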
Edit:
During training with dummy data using this fix, the loss stalls at 0.2 and does not converge to 0 as expected. Further adjustments are needed to resolve this convergence issue.