
Is there any plan for overlap and fused GEMM for JAX? #320

Closed
MoFHeka opened this issue Jul 14, 2023 · 6 comments

MoFHeka commented Jul 14, 2023

By combining with Alpa?

nouiz (Collaborator) commented Jul 14, 2023

The TE project doesn't plan to add more parallelism than what each framework already supports.
JAX already supports many forms of parallelism.
PAXml also adds some support for pipeline parallelism.
There is no plan to merge Alpa (or other similar projects) with TE.

TE tries to support as many forms of parallelism as it can. We are actively working on that for JAX.
You can also look at Rosetta, our supported models: https://github.com/NVIDIA/JAX-Toolbox/
It is small right now, but it will grow.

MoFHeka (Author) commented Jul 21, 2023

A few days ago I asked the core developers of Alpa, Yonghao Zhuang and Hao Zhang, and they told me that compiler optimization for DL models is no longer suitable for this era, and asked me to ask NVIDIA. In fact, Alpa's pipeline parallelism, for example, is difficult to integrate with JAX, and so is the sharding constraint mechanism JAX uses to support sequence parallelism. Their final recommendation was not to use Alpa/JAX without TPU.

MoFHeka (Author) commented Jul 21, 2023

But if we just use the Megatron framework, it has a lot of limitations. So is there any roadmap for supporting more frameworks and more techniques, such as https://arxiv.org/abs/2105.05720?

nouiz added the jax label Jan 27, 2024
nouiz (Collaborator) commented Jan 27, 2024

Quick update: we are adding sequence parallelism (SP) in this PR:
#602

At the end of last year we changed how TE tries to parallelize. It used to rely on xmap (hardcoding some cases); now it uses custom_partitioning. So all TE operations should act as native XLA operations and should respect uses of with_sharding_constraint().

This way, end users should be able to trigger all SPMD parallelism just by setting the input/output shardings or by adding with_sharding_constraint() in the right places.
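
Below is a minimal sketch of that workflow in plain JAX (the layer, axis name, and shapes are made up for illustration and are not TE's API): parallelism comes entirely from the shardings passed to jax.jit and from with_sharding_constraint(), which TE operations should now honor like any native XLA op. It assumes a recent JAX version where jax.jit accepts in_shardings/out_shardings.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# A 1D device mesh; the axis name "dp" is an arbitrary choice for this sketch.
mesh = Mesh(np.array(jax.devices()), ("dp",))

def layer(x, w):
    y = x @ w
    # Ask XLA to keep the activations sharded along "dp"; TE ops implemented
    # with custom_partitioning should respect constraints like this the same
    # way native XLA ops do.
    y = jax.lax.with_sharding_constraint(y, NamedSharding(mesh, P("dp", None)))
    return jax.nn.relu(y)

x = jnp.ones((8, 128))
w = jnp.ones((128, 128))

jitted = jax.jit(
    layer,
    in_shardings=(NamedSharding(mesh, P("dp", None)),
                  NamedSharding(mesh, P(None, None))),
    out_shardings=NamedSharding(mesh, P("dp", None)),
)
print(jitted(x, w).shape)  # (8, 128)
```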

The PR above makes it even simpler for SP.
I'll close this issue. If you don't think we have the proper building blocks, re-open it.
If you have specific requests, open new ones.

Note that the computation/communication overlap is work that started in XLA; TE/JAX can't control that. There are some XLA_FLAGS that allow enabling more overlap or playing with some configuration options. Models in JAX-Toolbox use some of them for speedups. We are hoping to enable more of those cases by default over the year.
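
As a hedged illustration of what that looks like in practice (flag names and defaults change across XLA releases, so treat this as an example and check the JAX-Toolbox docs for your build), one commonly used option is XLA's latency-hiding scheduler, which lets the compiler overlap collectives with compute:

```python
import os

# XLA_FLAGS must be set before JAX/XLA initializes, e.g. in the shell or at
# the very top of the script.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "")
    + " --xla_gpu_enable_latency_hiding_scheduler=true"
)

import jax  # imported after setting the flag so XLA picks it up

print(jax.devices())
```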

MoFHeka (Author) commented Jan 28, 2024

@nouiz I have seen this update, thank you for your work.
Another thing: do you have any plans to optimize the layer kernels of Praxis? JAX (Flax or Praxis) attention layers are constructed from einsum kernels, which can't be lowered to cuDNN GEMM or the latest cuDNN XLA FMHA kernel. When running attention layers on GPU, they can only be transformed into Triton kernels...
TE currently only supports a limited number of transformer models (for example, MoE is difficult to support) and does not yet support LoRA SFT. So it may be necessary to optimize the layer composition of the JAX ecosystem.
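
For context, this is roughly the einsum-style attention pattern that Flax/Praxis-like layers build (shapes and names here are illustrative, not the actual Praxis code); unless XLA pattern-matches a graph like this into a fused attention kernel, it stays as separate GEMM and softmax ops:

```python
import jax
import jax.numpy as jnp

def einsum_attention(q, k, v):
    # q, k, v: [batch, seq, heads, head_dim]
    scale = q.shape[-1] ** -0.5
    # Scores and weighted values expressed as einsums, the pattern described
    # above as hard to lower to the cuDNN fused attention kernel.
    logits = jnp.einsum("bqhd,bkhd->bhqk", q, k) * scale
    probs = jax.nn.softmax(logits, axis=-1)
    return jnp.einsum("bhqk,bkhd->bqhd", probs, v)

q = k = v = jnp.ones((2, 128, 8, 64))
print(jax.jit(einsum_attention)(q, k, v).shape)  # (2, 128, 8, 64)
```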

MoFHeka (Author) commented Jan 29, 2024

The last question is the same as NVIDIA/JAX-Toolbox#502.

MoFHeka closed this as completed Jan 29, 2024