With FSDP2 and transformer block compile, torch.compile saves both the SDPA output and the contiguous transposed tensor for backward:

torchtitan/torchtitan/models/llama/model.py, lines 210 to 213 in 7e93822
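For context, the referenced lines are the tail of `Attention.forward`. Here is a minimal runnable paraphrase (shapes and names are assumptions for illustration; see the permalink above for the exact code):

```python
import torch
import torch.nn.functional as F

bs, seq_len, n_heads, head_dim = 2, 16, 4, 8    # assumed toy shapes
dim = n_heads * head_dim
wo = torch.nn.Linear(dim, dim, bias=False)      # output projection
xq = xk = xv = torch.randn(bs, n_heads, seq_len, head_dim, requires_grad=True)

out = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
# ^ the SDPA output: (bs, n_heads, seq_len, head_dim) -- saved in both setups
out_t = out.transpose(1, 2).contiguous()
# ^ the contiguous transposed tensor: (bs, seq_len, n_heads, head_dim)
#   -- the extra tensor saved under FSDP2 + transformer block compile
wo_input = out_t.view(bs, seq_len, dim)         # (bs, seq_len, dim): the input to wo
y = wo(wo_input)
```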
However, with simpleFSDP and full model compile, torch.compile only saves the SDPA output. This means that FSDP2 saves an extra (bs, seq_len, dim) tensor per transformer block.

Traditionally, the SDPA output is required for the SDPA backward, and the input to wo is required for the wo backward. However, it may be profitable memory-wise to recompute one from the other (e.g. recompute the SDPA output by undoing the transpose of the wo input), as sketched below.
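A minimal sketch of that recompute direction (shapes and names are assumptions for illustration): since `view` and `transpose` both return views, the SDPA output layout can be recovered from the saved wo input essentially for free.

```python
import torch

bs, seq_len, n_heads, head_dim = 2, 16, 4, 8    # assumed toy shapes

# Suppose only the wo input was saved for backward: (bs, seq_len, n_heads * head_dim).
wo_input = torch.randn(bs, seq_len, n_heads * head_dim)

# Undo the view and the transpose to recover the SDPA output layout.
sdpa_output = wo_input.view(bs, seq_len, n_heads, head_dim).transpose(1, 2)
assert sdpa_output.shape == (bs, n_heads, seq_len, head_dim)
# sdpa_output is a non-contiguous view over wo_input's storage: no copy is made.
```

Going the other way (rebuilding the wo input from the SDPA output) would repeat the `.contiguous()` copy, so recomputing in the direction shown seems the cheaper of the two.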
One question is why the activations saved for backward differ between simpleFSDP with full model compile and FSDP2 with transformer block compile.
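One way to investigate (a sketch using standard PyTorch tooling, not something from this issue): in eager mode, `torch.autograd.graph.saved_tensors_hooks` enumerates every tensor autograd stashes for backward, and for compiled runs the saved activations appear as extra forward-graph outputs, which `TORCH_LOGS="aot_graphs"` will dump for comparison between the two setups.

```python
import torch
from torch.autograd.graph import saved_tensors_hooks

# Toy stand-in for a transformer block (hypothetical, for illustration only).
block = torch.nn.Sequential(
    torch.nn.Linear(64, 64),
    torch.nn.GELU(),
    torch.nn.Linear(64, 64),
)
x = torch.randn(2, 16, 64, requires_grad=True)

saved_shapes = []

def pack(t):
    saved_shapes.append(tuple(t.shape))  # record each tensor saved for backward
    return t

with saved_tensors_hooks(pack, lambda t: t):
    out = block(x)
out.sum().backward()
print(saved_shapes)  # shapes of the activations autograd kept (eager mode)
```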