Hi @vgoklani -- TE modules can be initialized under the with te.fp8_model_init(): context to allocate their primary weights in FP8 (as te.Float8Tensors) instead of allocating at a higher precision and maintaining separate FP8 buffers for compute.
I don't believe anyone has tried this in practice, but at least in principle, FSDP2's per-parameter sharding should work out-of-the-box with the torch.uint8 data underneath our te.Float8Tensors.
There are two things to be mindful of here:
1. You would not use the precompute_float8_dynamic_scale_for_fsdp(model) API from the linked example because TE already does this internally. You simply need to pass the process group for amax reductions (typically global/world group) into the te.fp8_autocast() context.
2. In the absence of native FP8 support in PyTorch, you cannot apply the optimizer step directly onto the FP8 model parameters. Consequently, te.fp8_model_init() is intended to be used with higher precision "master" copies of the model parameters in the optimizer.
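To illustrate why the optimizer needs higher-precision "master" copies, here is a minimal, library-free sketch. It is not TE code: IEEE half precision (via Python's struct module) stands in for FP8, and the names to_low_precision and train are hypothetical. The point it tries to show is that an update smaller than the low-precision rounding step is lost when applied directly to the low-precision weight, but accumulates correctly in a master copy.

```python
import struct


def to_low_precision(x: float) -> float:
    """Round x to IEEE half precision (assumption: a stand-in for FP8)."""
    return struct.unpack("e", struct.pack("e", x))[0]


def train(steps: int, lr: float, grad: float, use_master: bool) -> float:
    """Run `steps` SGD updates with a constant gradient.

    Returns the final low-precision weight that the model would compute with.
    """
    weight_lp = to_low_precision(1.0)  # low-precision model weight
    master = 1.0                       # higher-precision master copy
    for _ in range(steps):
        if use_master:
            # Update the master copy, then re-quantize it for compute.
            master -= lr * grad
            weight_lp = to_low_precision(master)
        else:
            # Update the low-precision weight in place; the result is
            # rounded back to low precision after every step.
            weight_lp = to_low_precision(weight_lp - lr * grad)
    return weight_lp


direct = train(steps=1000, lr=1e-4, grad=1.0, use_master=False)
with_master = train(steps=1000, lr=1e-4, grad=1.0, use_master=True)
print(direct, with_master)
```

In this toy setup the directly updated weight never moves from 1.0, because each step of 1e-4 is smaller than half the spacing between representable values near 1.0, while the master-copy variant ends close to 0.9. FP8 has even fewer mantissa bits than half precision, so the effect would be more pronounced there.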
If you experiment with TE + FSDP2, please share your experiences. We already support PyTorch's native FSDP but this involves TE modules carrying extra FP8 buffers for the compute while FSDP communication remains in higher precision. It would be great to extend our FSDP support to te.fp8_model_init() + FSDP2.
Adding to this, FSDP support should just be a matter of implementing fsdp_pre_all_gather and fsdp_post_all_gather methods in Float8Tensor, at least in principle.
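For anyone picking this up, the data flow those two hooks would have to implement can be sketched without PyTorch. Everything below is a hypothetical simplification: ToyQuantShard, quantize and toy_all_gather are made-up names, a plain scaled 8-bit integer encoding stands in for FP8, and the real FSDP2 hooks take additional arguments that differ between PyTorch versions. The idea is only that the pre-all-gather step hands the raw 1-byte payload to the collective while keeping the scale as metadata, and the post-all-gather step rebuilds the full quantized tensor from the gathered bytes.

```python
from dataclasses import dataclass


@dataclass
class ToyQuantShard:
    """Hypothetical stand-in for one rank's shard of a quantized weight."""

    data: bytes   # raw 1-byte-per-element payload (what would be torch.uint8)
    scale: float  # scale shared by all shards of this weight

    def pre_all_gather(self):
        # Send only the raw bytes through the collective; carry the scale
        # separately as metadata.
        return (self.data,), (self.scale,)

    @staticmethod
    def post_all_gather(all_gather_outputs, metadata):
        # Rebuild the unsharded quantized weight from the gathered bytes.
        (data,) = all_gather_outputs
        (scale,) = metadata
        return ToyQuantShard(data, scale)

    def dequantize(self):
        return [(b - 128) * self.scale for b in self.data]


def quantize(values, scale):
    """Toy 8-bit encoding (assumption: not a real FP8 format)."""
    return bytes(max(0, min(255, round(v / scale) + 128)) for v in values)


def toy_all_gather(per_rank_inputs):
    """Simulate an all-gather by concatenating payloads in rank order."""
    return (b"".join(inputs[0] for inputs in per_rank_inputs),)


weights = [-1.0, -0.5, 0.0, 0.25, 0.5, 1.0, 0.75, -0.25]
scale = 1.0 / 64
world_size = 2

# Shard the quantized payload evenly across ranks.
payload = quantize(weights, scale)
chunk = len(payload) // world_size
shards = [
    ToyQuantShard(payload[r * chunk:(r + 1) * chunk], scale)
    for r in range(world_size)
]

# Each rank runs the pre hook, the collective moves 1 byte per element,
# and the post hook reassembles the full quantized weight.
pre = [shard.pre_all_gather() for shard in shards]
gathered = toy_all_gather([inputs for inputs, _ in pre])
full = ToyQuantShard.post_all_gather(gathered, pre[0][1])
restored = full.dequantize()
print(restored)
```

The values here were chosen as multiples of the scale, so the round trip is exact. With real FP8 data, the scale handling would also have to agree across ranks before the gather, which is the part TE's amax reduction is responsible for.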
FSDP2 supports all-gather using FP8:
https://discuss.pytorch.org/t/distributed-w-torchtitan-enabling-float8-all-gather-in-fsdp2/209323
Wondering if we could do this directly using TransformerEngine instead of torch-ao?
Thanks!