FSDP: How to do all-gather using FP8? #1188

Open
vgoklani opened this issue Sep 17, 2024 · 2 comments


@vgoklani

FSDP2 supports all-gather using FP8:

https://discuss.pytorch.org/t/distributed-w-torchtitan-enabling-float8-all-gather-in-fsdp2/209323

Could we do this directly with TransformerEngine instead of torchao?

Thanks!

@denera
Collaborator

denera commented Sep 17, 2024

Hi @vgoklani, TE modules can be initialized inside the te.fp8_model_init() context (with te.fp8_model_init(): ...) so that their primary weights are allocated directly in FP8 (as te.Float8Tensors), instead of being allocated at a higher precision alongside separate FP8 buffers for compute.

I don't believe anyone has tried this in practice, but at least in principle, FSDP2's per-parameter sharding should work out-of-the-box with the torch.uint8 data underneath our te.Float8Tensors.
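To make the "uint8 data underneath" point concrete, here is a minimal pure-Python simulation. It is not TE or FSDP2 code: a weight is quantized once to FP8 E4M3 bytes with a single per-tensor scale, the byte payload is split into equal contiguous per-rank shards (a stand-in for FSDP2's per-parameter sharding), and the all-gather is modeled as concatenating the shards before dequantizing. The E4M3 codec, the helper names, and the even split are assumptions made for illustration.

```python
"""Simulate per-parameter sharding of an FP8 (E4M3) byte payload.

Pure-Python sketch: no torch, no TransformerEngine. All helper names are hypothetical.
"""
import math

E4M3_MAX = 448.0  # largest finite E4M3 magnitude


def decode_e4m3(code: int) -> float:
    """Decode one E4M3 byte: 1 sign bit, 4 exponent bits (bias 7), 3 mantissa bits."""
    sign = -1.0 if code & 0x80 else 1.0
    exp = (code >> 3) & 0xF
    man = code & 0x7
    if exp == 0xF and man == 0x7:
        return math.nan  # E4M3 has no infinities; S.1111.111 encodes NaN
    if exp == 0:
        return sign * (man / 8) * 2.0 ** -6  # subnormal range
    return sign * (1 + man / 8) * 2.0 ** (exp - 7)


# Every finite code with its value, used for nearest-value encoding.
_FINITE = [(decode_e4m3(c), c) for c in range(256) if not math.isnan(decode_e4m3(c))]


def encode_e4m3(x: float) -> int:
    """Round to the nearest finite E4M3 value (saturating; ties are not broken to even)."""
    x = max(-E4M3_MAX, min(E4M3_MAX, x))
    return min(_FINITE, key=lambda vc: abs(vc[0] - x))[1]


def quantize(weights, amax):
    """Per-tensor scaling: map [-amax, amax] onto [-448, 448] and keep the raw bytes."""
    scale = E4M3_MAX / amax
    return bytes(encode_e4m3(w * scale) for w in weights), scale


def shard(payload: bytes, world_size: int):
    """Split the byte payload into equal contiguous shards, one per rank."""
    assert len(payload) % world_size == 0, "this sketch assumes an even split"
    n = len(payload) // world_size
    return [payload[r * n:(r + 1) * n] for r in range(world_size)]


def all_gather(shards):
    """Model an all-gather as concatenating every rank's shard in rank order."""
    return b"".join(shards)


def dequantize(payload: bytes, scale: float):
    return [decode_e4m3(b) / scale for b in payload]


weights = [0.5, -1.25, 3.0, 0.01, -0.75, 2.0, -3.0, 0.125]
amax = max(abs(w) for w in weights)
payload, scale = quantize(weights, amax)
shards = shard(payload, world_size=4)  # each rank holds 2 bytes (2 FP8 values)
gathered = all_gather(shards)
restored = dequantize(gathered, scale)
print(gathered == payload)  # True: sharding the raw bytes round-trips exactly
print([round(r, 3) for r in restored])
```

Because the shards are plain bytes and the scale is one shared scalar, nothing in the sharding or gathering step needs to understand FP8; only the final dequantization does.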

There are two things to be mindful of here:

  1. You would not use the precompute_float8_dynamic_scale_for_fsdp(model) API from the linked example, because TE already computes the FP8 scales internally. You only need to pass the process group for amax reductions (typically the global/world group) to the te.fp8_autocast() context via its fp8_group argument.
  2. In the absence of native FP8 support in PyTorch, you cannot apply the optimizer step directly to the FP8 model parameters. Consequently, te.fp8_model_init() is intended to be used with higher-precision "master" copies of the model parameters held by the optimizer.
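The two points above can be sketched as a pure-Python simulation (not TransformerEngine code): the optimizer updates higher-precision master shards, amax is reduced across hypothetical ranks with a MAX (standing in for the amax reduction over the process group passed to te.fp8_autocast()), and FP8 compute weights are re-derived from the master copy with that shared scale. Assumptions: per-tensor current scaling (TE's default recipe uses delayed scaling with an amax history), plain SGD, and E4M3 rounding implemented by hand.

```python
"""Sketch: FP8 compute weights derived from higher-precision master weights.

Pure-Python simulation; names, values, and the scaling recipe are illustrative assumptions.
"""
import math

E4M3_MAX = 448.0


def round_to_e4m3(x: float) -> float:
    """Round to the nearest E4M3 value, ties to even, saturating at +/-448."""
    if x == 0.0:
        return 0.0
    x = max(-E4M3_MAX, min(E4M3_MAX, x))
    _, e = math.frexp(abs(x))  # abs(x) = m * 2**e with 0.5 <= m < 1
    exp = max(e - 1, -6)       # -6 is the smallest normal exponent; below it, subnormal spacing
    step = 2.0 ** (exp - 3)    # 3 mantissa bits -> 8 representable values per binade
    return math.copysign(round(abs(x) / step) * step, x)  # round() ties to even


def reduce_amax(local_shards):
    """Stand-in for an all-reduce with a MAX op over the amax process group."""
    return max(max(abs(v) for v in shard) for shard in local_shards)


def to_fp8(shard, amax):
    """Quantize a shard with the shared scale; returns the dequantized FP8 values."""
    scale = E4M3_MAX / amax
    return [round_to_e4m3(v * scale) / scale for v in shard]


# fp32 master weights, already sharded across two hypothetical ranks
master = [[0.30, -0.12], [0.07, 0.25]]
grads = [[0.01, -0.02], [0.03, 0.00]]
lr = 0.1

# 1. The optimizer updates the high-precision master copy, never the FP8 values.
master = [[w - lr * g for w, g in zip(ws, gs)] for ws, gs in zip(master, grads)]

# 2. Every rank must use the same scale, so amax is reduced across the group first.
amax = reduce_amax(master)

# 3. FP8 weights for the next forward pass are re-derived from the master copy.
fp8_weights = [to_fp8(shard, amax) for shard in master]
print(amax, fp8_weights)
```

The FP8 weights here are disposable: they are regenerated from the master copy after every step, which is why the master copy has to live in the optimizer.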

If you experiment with TE + FSDP2, please share your experience. We already support PyTorch's native FSDP, but there TE modules carry extra FP8 buffers for compute while FSDP communication remains in higher precision. It would be great to extend our FSDP support to te.fp8_model_init() + FSDP2.
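For a rough sense of what higher-precision communication costs, here is a back-of-the-envelope calculation of how many bytes each rank receives when all-gathering one evenly sharded parameter. The 4096 x 4096 weight and world size of 8 are hypothetical, and the per-tensor FP8 scale (a single scalar) is ignored.

```python
def all_gather_recv_bytes(numel: int, bytes_per_elem: int, world_size: int) -> int:
    """Bytes one rank receives in an all-gather of one evenly sharded parameter.

    Each rank already holds its own shard, so it receives the other
    (world_size - 1) shards of numel / world_size elements each.
    """
    assert numel % world_size == 0, "this sketch assumes an even split"
    return (world_size - 1) * (numel // world_size) * bytes_per_elem


numel = 4096 * 4096  # one hypothetical square linear-layer weight
world_size = 8
bf16 = all_gather_recv_bytes(numel, 2, world_size)
fp8 = all_gather_recv_bytes(numel, 1, world_size)
print(bf16, fp8, bf16 / fp8)  # 29360128 14680064 2.0
```

Gathering FP8 bytes halves the volume relative to BF16 (and quarters it relative to FP32), before accounting for any scale metadata.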

@timmoon10
Collaborator

Adding to this, FSDP2 support should, at least in principle, just be a matter of implementing the fsdp_pre_all_gather and fsdp_post_all_gather methods in Float8Tensor.
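The data flow through those two hooks can be mocked in plain Python. In real FSDP2 they are methods on a torch.Tensor subclass, and their exact signatures have changed across PyTorch releases, so the class and driver below (ToyFloat8, fake_fsdp_all_gather) are hypothetical and only mirror the shape of the protocol: pre-all-gather hands FSDP the raw bytes to communicate plus metadata (here, the scale) that stays local, and post-all-gather rebuilds an FP8 tensor from the gathered bytes and that metadata.

```python
"""Mock of FSDP2-style all-gather extension hooks for an FP8 tensor.

Hypothetical names throughout; no torch. The real hooks live on a torch.Tensor
subclass and take additional arguments that vary by PyTorch version.
"""
from dataclasses import dataclass
from typing import Any, List, Tuple


@dataclass
class ToyFloat8:
    """Stand-in for an FP8 tensor shard: raw uint8 bytes plus a per-tensor scale."""
    data: bytes
    scale: float

    def fsdp_pre_all_gather(self) -> Tuple[Tuple[bytes, ...], Any]:
        # Communicate only the 1-byte-per-element payload; keep the scale as local metadata.
        return (self.data,), (self.scale,)

    def fsdp_post_all_gather(self, all_gather_outputs, metadata) -> "ToyFloat8":
        # Rebuild the unsharded FP8 tensor from the gathered bytes and the local scale.
        (data,) = all_gather_outputs
        (scale,) = metadata
        return ToyFloat8(data, scale)


def fake_fsdp_all_gather(shards: List[ToyFloat8]) -> ToyFloat8:
    """Hypothetical driver playing FSDP2's role for one parameter, viewed from rank 0."""
    per_rank = [s.fsdp_pre_all_gather() for s in shards]
    local_metadata = per_rank[0][1]
    # Metadata is not communicated, so this sketch assumes amax was already reduced
    # across the group and every rank holds the same scale.
    assert all(meta == local_metadata for _, meta in per_rank)
    # All-gather each communicated input across ranks (here: concatenate in rank order).
    n_inputs = len(per_rank[0][0])
    gathered = tuple(b"".join(inputs[i] for inputs, _ in per_rank) for i in range(n_inputs))
    return shards[0].fsdp_post_all_gather(gathered, local_metadata)


shards = [ToyFloat8(bytes([0x38, 0x40]), 2.0), ToyFloat8(bytes([0xB8, 0x7E]), 2.0)]
full = fake_fsdp_all_gather(shards)
print(full.data.hex(), full.scale)  # 3840b87e 2.0
```

Keeping the scale out of the communicated inputs is what makes the shared amax reduction from point 1 matter: each rank reconstructs the gathered tensor with its own copy of the scale.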
