
Enable bfloat16 for micro sdpa kernel #2344

Merged
merged 3 commits into from Jan 11, 2025

Conversation

@h-sadia (Contributor) commented Jan 6, 2025

Description

This PR enables bfloat16 for micro_sdpa. It modifies the tile operations and the micro sdpa kernel to support the new data type.

Right now, this doesn't include quantization - that will be my next PR.
Large test cases have slight precision issues for some of the data points in the tensor - likely related to cumulative error in the accumulation chain. Discussing this with @dzarukin.
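For context, a minimal sketch of the conversions the tile operations need when bf16 values are carried in ushort storage (illustrative only, not the PR's actual tile_ops.h code):

    /* bf16 reuses the top 16 bits of an f32 bit pattern: up-conversion is a
       shift into the high half; down-conversion rounds the low 16 bits away. */
    float bf16_to_f32(ushort b) {
        return as_float((uint)b << 16);
    }

    ushort f32_to_bf16_rne(float f) {
        uint u = as_uint(f);
        /* round-to-nearest-even on the discarded bits (NaN handling omitted) */
        u += 0x7fffu + ((u >> 16) & 1u);
        return (ushort)(u >> 16);
    }

This also illustrates the precision note above: bf16 keeps only 8 mantissa bits, so long accumulation chains drift unless partial sums are kept in f32.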

Test cases of smaller sizes pass - for example:

> ./tests/benchdnn/benchdnn --engine=gpu --graph --dt=bf16 --in-shapes=0:1x1x16x64+1:1x1x16x64+8:1x1x16x64+5:1x1x1x16 -v7 --case=complex_fusion/mha/sdpa-plain-simplified-f16.json
> [INFO] Graph dump:
> {(0) MatMul}
>     In: { (0):bf16:1x1x16x64, (1):bf16:1x1x16x64 }
>     Out: { (2):bf16:1x1x16x16 }
> {(1) Divide}
>     In: { (2):bf16:1x1x16x16, (3):bf16:1 }
>     Out: { (4):bf16:1x1x16x16 }
> {(2) Add}
>     In: { (4):bf16:1x1x16x16, (5):bf16:1x1x1x16 }
>     Out: { (6):bf16:1x1x16x16 }
> {(3) SoftMax}
>     In: { (6):bf16:1x1x16x16 }
>     Out: { (7):bf16:1x1x16x16 }
> {(4) MatMul}
>     In: { (7):bf16:1x1x16x16, (8):bf16:1x1x16x64 }
>     Out: { (9):bf16:1x1x16x64 }
> 
> run: --graph --engine=gpu --dt=bf16 --in-shapes=0:1x1x16x64+1:1x1x16x64+5:1x1x1x16+8:1x1x16x64 --case=complex_fusion/mha/sdpa-plain-simplified-f16.json
> 0:PASSED __REPRO: --graph --engine=gpu --dt=bf16 --in-shapes=0:1x1x16x64+1:1x1x16x64+5:1x1x1x16+8:1x1x16x64 --case=complex_fusion/mha/sdpa-plain-simplified-f16.json
> tests:1 passed:1 skipped:0 mistrusted:0 unimplemented:0 invalid_arguments:0 failed:0 listed:0
./tests/benchdnn/benchdnn --engine=gpu --graph --dt=bf16 --in-shapes=0:1x1x16x32+1:1x1x16x32+9:1x1x16x32+4:1x1x1x1 -v7 --case=complex_fusion/mha/sdpa-plain-wo-mask-f16.json
[INFO] Graph dump:
{(3) MatMul}
   In: { (0):bf16:1x1x16x32, (1):bf16:1x1x16x32 }
   Out: { (2):bf16:1x1x16x16 }
{(6) Divide}
   In: { (2):bf16:1x1x16x16, (4):bf16:1x1x1x1 }
   Out: { (5):bf16:1x1x16x16 }
{(8) SoftMax}
   In: { (5):bf16:1x1x16x16 }
   Out: { (7):bf16:1x1x16x16 }
{(11) MatMul}
   In: { (7):bf16:1x1x16x16, (9):bf16:1x1x16x32 }
   Out: { (10):bf16:1x1x16x32 }

run: --graph --engine=gpu --dt=bf16 --in-shapes=0:1x1x16x32+1:1x1x16x32+4:1x1x1x1+9:1x1x16x32 --case=complex_fusion/mha/sdpa-plain-wo-mask-f16.json

0:PASSED __REPRO: --graph --engine=gpu --dt=bf16 --in-shapes=0:1x1x16x32+1:1x1x16x32+4:1x1x1x1+9:1x1x16x32 --case=complex_fusion/mha/sdpa-plain-wo-mask-f16.json
tests:1 passed:1 skipped:0 mistrusted:0 unimplemented:0 invalid_arguments:0 failed:0 listed:0

@h-sadia h-sadia requested a review from a team as a code owner January 6, 2025 21:45
@github-actions github-actions bot added the platform:gpu-intel Codeowner: @oneapi-src/onednn-gpu-intel label Jan 6, 2025
@h-sadia h-sadia force-pushed the hsadia/micro_sdpa_bf16 branch 6 times, most recently from 15429bf to 03109d6 Compare January 7, 2025 20:39
@petercad (Contributor) left a comment:

Thanks for the PR, Haleema!

@h-sadia h-sadia force-pushed the hsadia/micro_sdpa_bf16 branch from 03109d6 to 660916e Compare January 10, 2025 20:07
@h-sadia h-sadia changed the title from "WIP: Enable bfloat16 for micro sdpa kernel" to "Enable bfloat16 for micro sdpa kernel" Jan 10, 2025
@h-sadia h-sadia force-pushed the hsadia/micro_sdpa_bf16 branch 2 times, most recently from c6ea44d to 26ef545 Compare January 10, 2025 20:39
@h-sadia (Contributor, Author) commented Jan 10, 2025

make test
disable device_cpu
disable benchdnn_all
disable run_gtests
enable benchdnn_graph

Review threads (outdated, resolved):
src/gpu/intel/ocl/micro_sdpa.hpp
src/gpu/intel/ocl/micro_sdpa.cl
src/gpu/intel/ocl/micro_sdpa.cl
src/gpu/intel/ocl/tile_ops.h
src/gpu/intel/ocl/tile_ops.h
src/gpu/intel/ocl/tile_ops.h
@h-sadia h-sadia force-pushed the hsadia/micro_sdpa_bf16 branch 8 times, most recently from 79d0565 to 486bfa0 Compare January 10, 2025 22:37
@h-sadia h-sadia requested review from a team as code owners January 10, 2025 22:44
@github-actions github-actions bot added the component:tests Codeowner: @oneapi-src/onednn-arch label Jan 10, 2025
@h-sadia (Contributor, Author) commented Jan 10, 2025

make test
disable device_cpu
enable device_gpu
disable benchdnn_all
disable run_gtests
enable benchdnn_graph

@h-sadia h-sadia force-pushed the hsadia/micro_sdpa_bf16 branch from 7696807 to ae0dc9b Compare January 10, 2025 23:25
Comment on lines +42 to +43
#else // data type is bf16
#define VEC_TYPE2 ushort2
Contributor:

Suggested change
#else // data type is bf16
#define VEC_TYPE2 ushort2
#elif defined(QRY_DT_BF16)
#define VEC_TYPE2 ushort2
#else
#error "Not supported data type"
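For readability, the full chain this suggestion implies would look like the following (the f16 branch is assumed from the surrounding context of the file; it is not shown in the quoted hunk):

    #if defined(QRY_DT_F16)
    #define VEC_TYPE2 half2
    #elif defined(QRY_DT_BF16)
    #define VEC_TYPE2 ushort2
    #else
    #error "Not supported data type"
    #endif

Since OpenCL has no native bf16 vector type, ushort2 carries the raw bits.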

-    problem.Ta = problem.Tb = Type::f16;
+    if (qry_md()->data_type == data_type::f16) {
+        problem.Ta = problem.Tb = Type::f16;
+    } else { // data_type is bf16
Contributor:

Use else if here, and an else branch for the error/unimplemented case.

Contributor Author:

It's not needed, since the init function returns early if the data type is not f16 or bf16.

Contributor:

When the next data type starts being supported, it would be more difficult to track down what goes wrong at this particular spot than to catch it nicely here. This is about being nice to others in the future, not about the current state of the code.
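A sketch of the structure being requested (assuming a Type::bf16 in the GEMM problem descriptor and oneDNN's status::unimplemented, as used elsewhere in this file):

    if (qry_md()->data_type == data_type::f16) {
        problem.Ta = problem.Tb = Type::f16;
    } else if (qry_md()->data_type == data_type::bf16) {
        problem.Ta = problem.Tb = Type::bf16;
    } else {
        return status::unimplemented; // future data types fail loudly here
    }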

def_data_type(kernel_ctx, val_mdw.data_type(), "VAL");
def_data_type(kernel_ctx, dst_mdw.data_type(), "DST");
def_data_type(kernel_ctx,
        pd()->with_attn_mask() ? msk_mdw.data_type() : dnnl_f32, "MSK");
Contributor:

If there is no mask, shouldn't the data type be left undefined to catch errors earlier?

Also, maybe use the normal name "MASK" - not sure I see the value in shaving one letter off...

Contributor Author (@h-sadia) commented Jan 11, 2025:

The kernel is written in a way that the mask type has to be defined in any case. It's unavoidable.
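A hypothetical illustration of that constraint (names and structure are illustrative, not the kernel's actual code; build with, e.g., -DMSK_DATA_T=float -DWITH_ATTN_MASK=0):

    __kernel void sdpa_sketch(__global const MSK_DATA_T *msk,
            __global float *dst) {
        size_t i = get_global_id(0);
    #if WITH_ATTN_MASK
        dst[i] += convert_float(msk[i]);
    #else
        (void)msk; /* unused, but its type must still compile */
    #endif
    }

Because msk appears in the parameter list either way, MSK_DATA_T must always expand to a real type - hence the dnnl_f32 fallback above.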

Contributor:

Would you mind putting a comment to this effect next to this line, for other developers, in a separate PR?

Contributor Author (@h-sadia) commented Jan 13, 2025:

I can follow up on it, since I still have to create a PR for enabling quantization for bf16.

@@ -46,6 +46,8 @@
--reset --dt=f32,bf16,f16 --in-shapes=0:32x16x128x64+1:32x16x128x64+5:32x16x128x128+8:32x16x128x64 --case=complex_fusion/mha/sdpa-plain-simplified-f16.json
--reset --dt=f32,bf16,f16 --in-shapes=0:acbd+1:acbd+8:acbd --case=complex_fusion/mha/sdpa-plain-simplified-f16.json
--reset --dt=f32,bf16,f16 --in-shapes=3:384,3:384x384,3:1x16x384x384 --case=complex_fusion/mha/sdpa-plain-scale-by-mul-f16.json
--reset --dt=bf16 --in-shapes=0:1x1x16x64+1:1x1x16x64+8:1x1x16x64+5:1x1x1x16 --case=complex_fusion/mha/sdpa-plain-simplified-f16.json
Contributor:

Please file a tracker to support bigger shapes, since these are benchdnn false-positive cases.

Contributor Author:

Will do. Somehow this commit got messed up and all my other commits were squashed.

@h-sadia h-sadia force-pushed the hsadia/micro_sdpa_bf16 branch from 0c7a877 to ce712c9 Compare January 11, 2025 00:27
@github-actions github-actions bot removed the component:tests Codeowner: @oneapi-src/onednn-arch label Jan 11, 2025
@h-sadia h-sadia merged commit f145cbe into main Jan 11, 2025
3 checks passed
@h-sadia h-sadia deleted the hsadia/micro_sdpa_bf16 branch January 11, 2025 00:32
@h-sadia h-sadia restored the hsadia/micro_sdpa_bf16 branch January 11, 2025 00:35
h-sadia added a commit that referenced this pull request Jan 11, 2025
Labels
platform:gpu-intel Codeowner: @oneapi-src/onednn-gpu-intel