
Does Float8Linear support Tensor Parallelism and Sequence Parallelism? #1198

Open
zigzagcai opened this issue Oct 30, 2024 · 1 comment

zigzagcai commented Oct 30, 2024

We know that Transformer Engine supports FP8 training with data parallel + tensor parallel + sequence parallel: https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/advanced_optimizations.html#Multi-GPU-training

However, when I checked the source code of swap_linear_modules and Float8Linear, and the torchao documentation and discussions, I could only find support for FP8 with FSDP (as far as I know).

So, does torchao also support tensor parallelism and sequence parallelism with Float8Linear?

Thanks!

vkuzo (Contributor) commented Oct 30, 2024

Hi @zigzagcai, we do support TP and SP, implemented via DTensor (https://pytorch.org/docs/stable/distributed.tensor.parallel.html). We have designed the float8 APIs to be orthogonal to TP/SP, so for the user the workflow would be:

  1. start with a high precision model
  2. convert parts of the model to float8 (torchao.float8)
  3. define the distributed strategy for the model (https://pytorch.org/docs/stable/distributed.tensor.parallel.html)

If you are ok with distributed communications happening in high precision, then (2) and (3) are independent. If you are interested in doing the all-gathers in low precision, then (2) and (3) interact to coordinate on low precision casting.
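
For concreteness, here is a minimal sketch of the independent case: convert with torchao.float8 first, then apply a standard DTensor TP plan so communications stay in high precision. The toy FFN module, the "w1"/"w2" names, and the mesh setup are illustrative assumptions, not something torchao prescribes.

```python
# Minimal sketch of the "independent" case: convert to float8 first, then apply
# a standard DTensor TP plan so communications stay in high precision.
# The toy FFN module, the "w1"/"w2" names, and the mesh setup are illustrative
# assumptions.
import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)
from torchao.float8 import convert_to_float8_training


class FFN(nn.Module):
    def __init__(self, dim: int, hidden: int):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden, bias=False)
        self.w2 = nn.Linear(hidden, dim, bias=False)

    def forward(self, x):
        return self.w2(torch.relu(self.w1(x)))


# (1) start with a high precision model
model = FFN(dim=1024, hidden=4096).cuda()

# (2) convert parts of the model to float8
convert_to_float8_training(model)

# (3) define the distributed strategy; assumes torch.distributed is already
# initialized (e.g. via torchrun). Plain ColwiseParallel/RowwiseParallel keep
# the all-gathers in high precision, so steps (2) and (3) do not interact.
mesh = init_device_mesh("cuda", (torch.distributed.get_world_size(),))
parallelize_module(
    model,
    mesh,
    {"w1": ColwiseParallel(), "w2": RowwiseParallel()},
)
```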

Here is an e2e example of (2) and (3) interacting with all-gathers happening in low precision: see `_test_fp8_mlp_tensor_parallelism_base` in the torchao float8 tests. If you were to take that example and replace Float8ColwiseParallel with ColwiseParallel, etc., then you'd get to a point where (2) and (3) are independent.
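
As a rough sketch of the low-precision all-gather variant, the plan from the earlier snippet could swap in the Float8 parallel styles mentioned above. The import path (torchao.float8.float8_tensor_parallel) is an assumption based on that test file, and "model", "mesh", and the "w1"/"w2" names reuse the earlier sketch.

```python
# Sketch of the "interacting" case: same plan as above, but using the Float8
# parallel styles so the float8 cast is coordinated with TP/SP and the
# all-gathers happen in low precision. Import path is an assumption based on
# the torchao test referenced above.
from torch.distributed.tensor.parallel import parallelize_module
from torchao.float8.float8_tensor_parallel import (
    Float8ColwiseParallel,
    Float8RowwiseParallel,
)

parallelize_module(
    model,  # already converted with convert_to_float8_training, as above
    mesh,
    {"w1": Float8ColwiseParallel(), "w2": Float8RowwiseParallel()},
)
```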

A full e2e training example of how all of this fits together is https://github.com/pytorch/torchtitan.

Let me keep this issue open to track adding more information to the README about how this works. Thanks for the question.
