[Feature Request] Need Matmul Attention layer instead of Einsum to support GPU running #502

Open
MoFHeka opened this issue Jan 27, 2024 · 4 comments

Comments

@MoFHeka

MoFHeka commented Jan 27, 2024

The einsum kernels in Praxis can't be lowered to cuDNN GEMM, which seriously hurts compute performance: the JAX version of the Attention layer is much slower than the TensorFlow version.
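For reference, here is a minimal sketch of the two formulations being compared: dot-product attention written with `jnp.einsum` (similar in spirit to, but not copied from, the Praxis/Flax layers) and the same computation written with batched `jnp.matmul`. The `[batch, seq_len, num_heads, head_dim]` layout and the function names are assumptions for illustration. In JAX both calls are traced to `lax.dot_general`, so the sketch only checks that the two forms agree numerically; it does not show which GPU kernels XLA selects for either.

```python
import jax
import jax.numpy as jnp

# Assumed (hypothetical) layout: [batch, seq_len, num_heads, head_dim].

def attention_einsum(q, k, v):
    """Scaled dot-product attention expressed with einsum."""
    scale = q.shape[-1] ** -0.5
    logits = jnp.einsum("btnh,bsnh->bnts", q * scale, k)  # [B, N, T, S]
    probs = jax.nn.softmax(logits, axis=-1)
    return jnp.einsum("bnts,bsnh->btnh", probs, v)        # [B, T, N, H]

def attention_matmul(q, k, v):
    """Same computation expressed as batched matmuls over [B, N, ...]."""
    scale = q.shape[-1] ** -0.5
    # Move the heads axis next to batch so matmul batches over (B, N).
    qh = jnp.transpose(q, (0, 2, 1, 3)) * scale  # [B, N, T, H]
    kh = jnp.transpose(k, (0, 2, 1, 3))          # [B, N, S, H]
    vh = jnp.transpose(v, (0, 2, 1, 3))          # [B, N, S, H]
    logits = jnp.matmul(qh, jnp.swapaxes(kh, -1, -2))  # [B, N, T, S]
    probs = jax.nn.softmax(logits, axis=-1)
    out = jnp.matmul(probs, vh)                  # [B, N, T, H]
    return jnp.transpose(out, (0, 2, 1, 3))      # back to [B, T, N, H]

kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (2, 8, 4, 16))
k = jax.random.normal(kk, (2, 8, 4, 16))
v = jax.random.normal(kv, (2, 8, 4, 16))

out_einsum = jax.jit(attention_einsum)(q, k, v)
out_matmul = jax.jit(attention_matmul)(q, k, v)
print(float(jnp.max(jnp.abs(out_einsum - out_matmul))))
```

A matmul-based layer like this would be one way for a wrapper library to offer an alternative code path without changing the layer's inputs or outputs.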

@nouiz
Collaborator

nouiz commented Jan 27, 2024

Why do you say cuDNN GEMM isn't used? Normally it should be. Can you provide an example where cuDNN GEMM isn't used in such a case?
And why did you open this here and not in the TransformerEngine repo?
Did I misunderstand the request?

@MoFHeka
Author

MoFHeka commented Jan 28, 2024

@nouiz Yes, TransformerEngine does use cuDNN GEMM.

But JAX (Flax or Praxis) attention layers are built from einsum kernels, which can't be lowered to cuDNN GEMM or to the latest cuDNN XLA FMHA kernel. According to the XLA dump log, when attention layers run on GPU they are only transformed into Triton kernels...
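One way to check which kernels XLA actually picked, without reading a full `XLA_FLAGS=--xla_dump_to=<dir>` dump, is to compile the function and scan the optimized HLO for custom-call targets. The sketch below assumes only `jax.jit(...).lower(...).compile().as_text()`; the helper name `custom_call_targets` and the attention-logits function are hypothetical. On a GPU backend one would look for cuBLAS- or cuDNN-related targets (for example `__cublas$gemm`), but the exact names depend on the XLA version, and on CPU the list is typically empty or unrelated.

```python
import re

import jax
import jax.numpy as jnp

def attention_logits_einsum(q, k):
    """Hypothetical einsum-based attention logits, [B, T, N, H] x [B, S, N, H] -> [B, N, T, S]."""
    return jnp.einsum("btnh,bsnh->bnts", q, k)

def custom_call_targets(fn, *args):
    """Compile fn for the default backend and list custom-call targets in the optimized HLO."""
    hlo_text = jax.jit(fn).lower(*args).compile().as_text()
    return sorted(set(re.findall(r'custom_call_target="([^"]+)"', hlo_text)))

q = jnp.ones((2, 8, 4, 16), dtype=jnp.float32)
k = jnp.ones((2, 8, 4, 16), dtype=jnp.float32)
targets = custom_call_targets(attention_logits_einsum, q, k)
print(jax.default_backend(), targets)
```

Running the same helper on an einsum-based and a matmul-based version of a layer gives a quick side-by-side comparison of the kernels XLA selects on the target hardware.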

TE currently supports only a limited number of transformer models (MoE, for example, is difficult to support) and does not yet support LoRA SFT. So it may be necessary to optimize how layers are composed in the JAX ecosystem itself.

Sorry, I'm not sure where to file this request, because it doesn't look like something the TE team should be responsible for. As I understand it, the TE team is only responsible for the 'custom_call' part of JAX.


@nouiz
Collaborator

nouiz commented Jan 29, 2024

I think you will be interested in this PR: jax-ml/jax#18814

@MoFHeka
Author

MoFHeka commented Jan 29, 2024

@nouiz Cool, thank you!
It would be nice if someone could also change the code in JAX wrapper libraries like Praxis, Flax, etc., since they are currently written with einsum.
This matters because Jax-Toolbox uses Paxml, and Paxml uses Praxis, but Praxis is written without matmul, regardless of what kernel generation JAX core supports.
