Enable bfloat16 for micro sdpa kernel #2344
Conversation
Thanks for the PR, Haleema!
#else // data type is bf16
#define VEC_TYPE2 ushort2
Suggested change:
-#else // data type is bf16
-#define VEC_TYPE2 ushort2
+#elif defined(QRY_DT_BF16)
+#define VEC_TYPE2 ushort2
+#else
+#error "Not supported data type"
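For reference, a sketch of how the full dispatch chain could look with this suggestion applied; the `QRY_DT_F16` branch and the `half2` mapping are assumptions for illustration, not lines from this PR.

```cpp
// Hypothetical full chain, assuming the host side defines QRY_DT_F16 / QRY_DT_BF16
// via def_data_type(); half2 for the f16 case is an assumption.
#if defined(QRY_DT_F16)
#define VEC_TYPE2 half2
#elif defined(QRY_DT_BF16)
#define VEC_TYPE2 ushort2 /* bf16 elements carried as raw 16-bit values */
#else
#error "Not supported data type"
#endif
```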
-    problem.Ta = problem.Tb = Type::f16;
+    if (qry_md()->data_type == data_type::f16) {
+        problem.Ta = problem.Tb = Type::f16;
+    } else { // data_type is bf16
Please use `else if` here, and `else` for the error/unimplemented case.
It's not needed, since the init function will return if it's not f16 or bf16.
When the next data type starts being supported, it will be more difficult to track down what goes wrong at this particular spot than to catch it nicely here. This is about being nice to others in the future, not about the current state in the present.
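A rough sketch of the structure being asked for, assuming the names from the snippet above (the bf16 branch and the `unimplemented` return are illustrative, not the PR's actual code):

```cpp
// Sketch only: handle each supported type explicitly and fail loudly otherwise.
if (qry_md()->data_type == data_type::f16) {
    problem.Ta = problem.Tb = Type::f16;
} else if (qry_md()->data_type == data_type::bf16) {
    problem.Ta = problem.Tb = Type::bf16;
} else {
    return status::unimplemented; // catch new/unsupported data types right here
}
```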
def_data_type(kernel_ctx, val_mdw.data_type(), "VAL");
def_data_type(kernel_ctx, dst_mdw.data_type(), "DST");
def_data_type(kernel_ctx,
        pd()->with_attn_mask() ? msk_mdw.data_type() : dnnl_f32, "MSK");
If there is no mask, shouldn't the data type be left undefined to catch errors earlier?
Maybe use a normal name, "MASK"; not sure I see the value of shaving one letter off...
The kernel is written in a way that the mask has to be defined in any case. It's unavoidable.
Would you mind putting a comment to that effect next to this line for other developers, in a separate PR?
I can follow up on it, since I still have to create a PR for enabling quantization for bf16.
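For reference, the follow-up comment could read something like the sketch below; the wording is only a suggestion, and the code line is unchanged from the diff above.

```cpp
// The micro_sdpa kernel always references the mask data type, even when no
// attention mask is supplied, so MSK must always be defined; fall back to f32.
def_data_type(kernel_ctx,
        pd()->with_attn_mask() ? msk_mdw.data_type() : dnnl_f32, "MSK");
```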
@@ -46,6 +46,8 @@
--reset --dt=f32,bf16,f16 --in-shapes=0:32x16x128x64+1:32x16x128x64+5:32x16x128x128+8:32x16x128x64 --case=complex_fusion/mha/sdpa-plain-simplified-f16.json
--reset --dt=f32,bf16,f16 --in-shapes=0:acbd+1:acbd+8:acbd --case=complex_fusion/mha/sdpa-plain-simplified-f16.json
--reset --dt=f32,bf16,f16 --in-shapes=3:384,3:384x384,3:1x16x384x384 --case=complex_fusion/mha/sdpa-plain-scale-by-mul-f16.json
--reset --dt=bf16 --in-shapes=0:1x1x16x64+1:1x1x16x64+8:1x1x16x64+5:1x1x1x16 --case=complex_fusion/mha/sdpa-plain-simplified-f16.json
Please file a tracker to support bigger shapes, since those are benchdnn false-positive cases.
Will do. Somehow this commit got messed up and all my other commits were squashed.
Description
This PR focuses on enabling bfloat16 for micro_sdpa. It modifies the tile operations along with the micro_sdpa kernel to enable this new data type.
Right now, this doesn't include quantization - this will be my next PR.
Large test cases have slight precision issues for some of the data points in the tensor, likely related to cumulative error in the accumulation chain. I am discussing it with @dzarukin.
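To illustrate the kind of drift involved, here is a small standalone sketch (not code from this PR) that assumes the standard bf16 layout, i.e. the upper 16 bits of an f32, which is also why the kernel can carry bf16 values in ushort2:

```cpp
#include <cstdint>
#include <cstdio>
#include <cstring>

// Truncate a float to bf16 precision (8 mantissa bits). Real hardware rounds
// to nearest-even; truncation is used here only to keep the sketch short.
static float to_bf16(float x) {
    uint32_t u;
    std::memcpy(&u, &x, sizeof(u));
    u &= 0xFFFF0000u;
    std::memcpy(&x, &u, sizeof(x));
    return x;
}

int main() {
    float acc_f32 = 0.f, acc_bf16 = 0.f;
    for (int i = 0; i < 4096; ++i) {
        acc_f32 += 0.001f;
        acc_bf16 = to_bf16(acc_bf16 + to_bf16(0.001f));
    }
    // The bf16-rounded accumulation chain drifts visibly from the f32 reference.
    std::printf("f32 sum: %f  bf16-rounded sum: %f\n", acc_f32, acc_bf16);
    return 0;
}
```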
Test cases of smaller sizes pass - for example: